Add `ir_data_utils.reader`
This adds an `ir_data_utils.reader` wrapper that will be used to
access potentially unset chains of fields in an `ir_data` instance. For
now this is simply an annotation that marks sites that are accessing
missing fields.
Part of #118.
diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py
index b8c9675..55c0a6e 100644
--- a/compiler/back_end/cpp/header_generator.py
+++ b/compiler/back_end/cpp/header_generator.py
@@ -28,6 +28,7 @@
from compiler.util import attribute_util
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import name_conversion
from compiler.util import resources
@@ -1475,7 +1476,7 @@
def _verify_namespace_attribute(attr, source_file_name, errors):
if attr.name.text != attributes.Attribute.NAMESPACE:
return
- namespace_value = attr.value.string_constant
+ namespace_value = ir_data_utils.reader(attr).value.string_constant
if not re.match(_NS_RE, namespace_value.text):
if re.match(_NS_EMPTY_RE, namespace_value.text):
errors.append([error.error(
diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py
index 6735075..d8c637c 100644
--- a/compiler/front_end/attribute_checker.py
+++ b/compiler/front_end/attribute_checker.py
@@ -25,6 +25,7 @@
from compiler.util import attribute_util
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -433,7 +434,7 @@
def _verify_back_end_attributes(attribute, expected_back_ends, source_file_name,
ir, errors):
- back_end_text = attribute.back_end.text
+ back_end_text = ir_data_utils.reader(attribute).back_end.text
if back_end_text not in expected_back_ends:
expected_back_ends_for_error = expected_back_ends - {""}
errors.append([error.error(
diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py
index aa62add..fabcf04 100644
--- a/compiler/front_end/constraints.py
+++ b/compiler/front_end/constraints.py
@@ -17,6 +17,7 @@
from compiler.front_end import attributes
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import resources
from compiler.util import traverse_ir
@@ -52,7 +53,8 @@
"""Checks that inner array dimensions are constant."""
if type_ir.WhichOneof("size") == "automatic":
errors.append([error.error(
- source_file_name, type_ir.element_count.source_location,
+ source_file_name,
+ ir_data_utils.reader(type_ir).element_count.source_location,
"Array dimensions can only be omitted for the outermost dimension.")])
elif type_ir.WhichOneof("size") == "element_count":
if not ir_util.is_constant(type_ir.element_count):
diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py
index cd6a851..e364399 100644
--- a/compiler/front_end/expression_bounds.py
+++ b/compiler/front_end/expression_bounds.py
@@ -634,7 +634,7 @@
def _compute_constraints_of_choice_operator(expression):
"""Computes the constraints of a choice operation '?:'."""
- condition, if_true, if_false = expression.function.args
+ condition, if_true, if_false = ir_data_utils.reader(expression).function.args
expression = ir_data_utils.builder(expression)
if condition.type.boolean.HasField("value"):
# The generated expressions for $size_in_bits and $size_in_bytes look like
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index ac79c47..d1bef9a 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -227,7 +227,7 @@
source_file_name, errors):
if not expression_ir.HasField("builtin_reference"):
return
- if expression_ir.builtin_reference.canonical_name.object_path[0] != "$next":
+ if ir_data_utils.reader(expression_ir).builtin_reference.canonical_name.object_path[0] != "$next":
return
if not last_location:
errors.append([
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index d8a226a..9ad4aee 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -24,7 +24,7 @@
def _type_check_expression(expression, source_file_name, ir, errors):
"""Checks and annotates the type of an expression and all subexpressions."""
- if expression.type.WhichOneof("type"):
+ if ir_data_utils.reader(expression).type.WhichOneof("type"):
# This expression has already been type checked.
return
expression_variety = expression.WhichOneof("expression")
@@ -54,7 +54,7 @@
def _type_check(expression, source_file_name, errors, type_oneof, type_name,
expression_name):
- if expression.type.WhichOneof("type") != type_oneof:
+ if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof:
errors.append([
error.error(source_file_name, expression.source_location,
"{} must be {}.".format(expression_name, type_name))
diff --git a/compiler/front_end/write_inference.py b/compiler/front_end/write_inference.py
index db555bc..7cfe7df 100644
--- a/compiler/front_end/write_inference.py
+++ b/compiler/front_end/write_inference.py
@@ -220,12 +220,13 @@
ir_data_utils.builder(field).write_method.physical = True
return
+ field_checker = ir_data_utils.reader(field)
field_builder = ir_data_utils.builder(field)
# A virtual field cannot be a direct alias if it has an additional
# requirement.
requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
- if (field.read_transform.WhichOneof("expression") != "field_reference" or
+ if (field_checker.read_transform.WhichOneof("expression") != "field_reference" or
requires_attr is not None):
inverse = _invert_expression(field.read_transform, ir)
if inverse:
diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py
index 0b72084..063f38e 100644
--- a/compiler/util/attribute_util.py
+++ b/compiler/util/attribute_util.py
@@ -20,6 +20,7 @@
from compiler.util import error
from compiler.util import ir_data
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
@@ -30,7 +31,7 @@
def _attribute_name_for_errors(attr):
- if attr.back_end.text:
+ if ir_data_utils.reader(attr).back_end.text:
return f"({attr.back_end.text}) {attr.name.text}"
else:
return attr.name.text
@@ -98,7 +99,7 @@
def string_from_list(valid_values):
"""Checks if the given attr has one of the valid_values."""
def _string_from_list(attr, module_source_file):
- if attr.value.string_constant.text not in valid_values:
+ if ir_data_utils.reader(attr).value.string_constant.text not in valid_values:
return [[error.error(module_source_file,
attr.value.source_location,
"Attribute '{name}' must be '{options}'.".format(
@@ -252,15 +253,17 @@
errors = []
already_seen_attributes = {}
for attr in attribute_list:
- if attr.back_end.text:
+ field_checker = ir_data_utils.reader(attr)
+ if field_checker.back_end.text:
if attr.back_end.text != back_end:
continue
else:
if back_end is not None:
continue
attribute_name = _attribute_name_for_errors(attr)
- if (attr.name.text, attr.is_default) in already_seen_attributes:
- original_attr = already_seen_attributes[attr.name.text, attr.is_default]
+ attr_key = (field_checker.name.text, field_checker.is_default)
+ if attr_key in already_seen_attributes:
+ original_attr = already_seen_attributes[attr_key]
errors.append([
error.error(module_source_file,
attr.source_location,
@@ -269,9 +272,9 @@
original_attr.source_location,
"Original attribute")])
continue
- already_seen_attributes[attr.name.text, attr.is_default] = attr
+ already_seen_attributes[attr_key] = attr
- if (attr.name.text, attr.is_default) not in attribute_specs:
+ if attr_key not in attribute_specs:
if attr.is_default:
error_message = "Attribute '{}' may not be defaulted on {}.".format(
attribute_name, context_name)
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index ed158b9..479d974 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -37,3 +37,11 @@
This is a no-op and just used for annotation for now.
"""
return target
+
+
+def reader(obj: ir_data.Message) -> ir_data.Message:
+ """A read-only wrapper for querying chains of IR data fields.
+
+ This is a no-op and just used for annotation for now.
+ """
+ return obj
diff --git a/compiler/util/ir_util.py b/compiler/util/ir_util.py
index f86ff74..02dc647 100644
--- a/compiler/util/ir_util.py
+++ b/compiler/util/ir_util.py
@@ -17,6 +17,7 @@
import operator
from compiler.util import ir_data
+from compiler.util import ir_data_utils
_FIXED_SIZE_ATTRIBUTE = "fixed_size_in_bits"
@@ -79,6 +80,7 @@
def is_constant_type(expression_type):
"""Returns True if expression_type is inhabited by a single value."""
+ expression_type = ir_data_utils.reader(expression_type)
return (expression_type.integer.modulus == "infinity" or
expression_type.boolean.HasField("value") or
expression_type.enumeration.HasField("value"))
@@ -86,6 +88,7 @@
def constant_value(expression, bindings=None):
"""Evaluates expression with the given bindings."""
+ expression = ir_data_utils.reader(expression)
if expression.WhichOneof("expression") == "constant":
return int(expression.constant.value)
elif expression.WhichOneof("expression") == "constant_reference":
@@ -252,7 +255,7 @@
if len(path) > 1:
return None
for parameter in type_definition.runtime_parameter:
- if parameter.name.name.text == path[0]:
+ if ir_data_utils.reader(parameter).name.name.text == path[0]:
return parameter
return None
@@ -389,4 +392,4 @@
"""Returns true if the field is read-only."""
# For now, all virtual fields are read-only, and no non-virtual fields are
# read-only.
- return field_ir.write_method.read_only
+ return ir_data_utils.reader(field_ir).write_method.read_only