Replace `WhichOneof("x")` with `which_x`. (#197)
This change refactors `OneOfField` so that all fields in a given `oneof`
construct share the same backing attributes on their container class --
`which_{oneof name}`, which holds the (string) name of the
currently-active member of the oneof named `{oneof name}` (or `None` if
no member is active), and `_value_{oneof name}`, which holds the value
of the currently-active member (or `None`).
This avoids looping through field specs in order to do an update or to
figure out which member of a `oneof` is currently active.
Since the `WhichOneof()` method is now a trivial read of a
similarly-named attribute, it can be inlined for a small decrease in
overall code size and without sacrificing readability.
As a result of these changes, the compiler now runs 4.5% faster on my
large test `.emb`.
diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py
index 3b82db0..0201d12 100644
--- a/compiler/back_end/cpp/header_generator.py
+++ b/compiler/back_end/cpp/header_generator.py
@@ -592,19 +592,19 @@
def _cpp_basic_type_for_expression_type(expression_type, ir):
"""Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType."""
- if expression_type.WhichOneof("type") == "integer":
+ if expression_type.which_type == "integer":
return _cpp_integer_type_for_range(
int(expression_type.integer.minimum_value),
int(expression_type.integer.maximum_value),
)
- elif expression_type.WhichOneof("type") == "boolean":
+ elif expression_type.which_type == "boolean":
return "bool"
- elif expression_type.WhichOneof("type") == "enumeration":
+ elif expression_type.which_type == "enumeration":
return _get_fully_qualified_name(
expression_type.enumeration.name.canonical_name, ir
)
else:
- assert False, "Unknown expression type " + expression_type.WhichOneof("type")
+ assert False, "Unknown expression type " + expression_type.which_type
def _cpp_basic_type_for_expression(expression, ir):
@@ -667,12 +667,12 @@
enum_types = set()
have_boolean_types = False
for subexpression in [expression] + list(args):
- if subexpression.type.WhichOneof("type") == "integer":
+ if subexpression.type.which_type == "integer":
minimum_integers.append(int(subexpression.type.integer.minimum_value))
maximum_integers.append(int(subexpression.type.integer.maximum_value))
- elif subexpression.type.WhichOneof("type") == "enumeration":
+ elif subexpression.type.which_type == "enumeration":
enum_types.add(_cpp_basic_type_for_expression(subexpression, ir))
- elif subexpression.type.WhichOneof("type") == "boolean":
+ elif subexpression.type.which_type == "boolean":
have_boolean_types = True
# At present, all Emboss functions other than `$has` take and return one of
# the following:
@@ -820,7 +820,7 @@
# will fit into C++ types, or that operator arguments and return types can fit
# in the same type: expressions like `-0x8000_0000_0000_0000` and
# `0x1_0000_0000_0000_0000 - 1` can appear.
- if expression.type.WhichOneof("type") == "integer":
+ if expression.type.which_type == "integer":
if expression.type.integer.modulus == "infinity":
return _ExpressionResult(
_render_integer_for_expression(
@@ -828,31 +828,29 @@
),
True,
)
- elif expression.type.WhichOneof("type") == "boolean":
+ elif expression.type.which_type == "boolean":
if expression.type.boolean.HasField("value"):
if expression.type.boolean.value:
return _ExpressionResult(_maybe_type("bool") + "(true)", True)
else:
return _ExpressionResult(_maybe_type("bool") + "(false)", True)
- elif expression.type.WhichOneof("type") == "enumeration":
+ elif expression.type.which_type == "enumeration":
if expression.type.enumeration.HasField("value"):
return _ExpressionResult(
_render_enum_value(expression.type.enumeration, ir), True
)
else:
# There shouldn't be any "opaque" type expressions here.
- assert False, "Unhandled expression type {}".format(
- expression.type.WhichOneof("type")
- )
+ assert False, "Unhandled expression type {}".format(expression.type.which_type)
result = None
# Otherwise, render the operation.
- if expression.WhichOneof("expression") == "function":
+ if expression.which_expression == "function":
result = _render_builtin_operation(expression, ir, field_reader, subexpressions)
- elif expression.WhichOneof("expression") == "field_reference":
+ elif expression.which_expression == "field_reference":
result = field_reader.render_field(expression, ir, subexpressions)
elif (
- expression.WhichOneof("expression") == "builtin_reference"
+ expression.which_expression == "builtin_reference"
and expression.builtin_reference.canonical_name.object_path[-1]
== "$logical_value"
):
@@ -982,7 +980,7 @@
definitions should be placed after the class definition. These are
separated to satisfy C++'s declaration-before-use requirements.
"""
- if field_ir.write_method.WhichOneof("method") == "alias":
+ if field_ir.write_method.which_method == "alias":
return _generate_field_indirection(field_ir, enclosing_type_name, ir)
read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
@@ -1011,7 +1009,7 @@
_TEMPLATES.structure_single_virtual_field_method_definitions
)
- if field_ir.write_method.WhichOneof("method") == "transform":
+ if field_ir.write_method.which_method == "transform":
destination = _render_variable(
ir_util.hashable_form_of_field_reference(
field_ir.write_method.transform.destination
@@ -1042,15 +1040,15 @@
assert logical_type, "Could not find appropriate C++ type for {}".format(
field_ir.read_transform
)
- if field_ir.read_transform.type.WhichOneof("type") == "integer":
+ if field_ir.read_transform.type.which_type == "integer":
write_to_text_stream_function = "WriteIntegerViewToTextStream"
- elif field_ir.read_transform.type.WhichOneof("type") == "boolean":
+ elif field_ir.read_transform.type.which_type == "boolean":
write_to_text_stream_function = "WriteBooleanViewToTextStream"
- elif field_ir.read_transform.type.WhichOneof("type") == "enumeration":
+ elif field_ir.read_transform.type.which_type == "enumeration":
write_to_text_stream_function = "WriteEnumViewToTextStream"
else:
assert False, "Unexpected read-only virtual field type {}".format(
- field_ir.read_transform.type.WhichOneof("type")
+ field_ir.read_transform.type.which_type
)
value_is_ok = _generate_validator_expression_for(field_ir, ir)
diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py
index 232eee8..9920db4 100644
--- a/compiler/front_end/attribute_checker.py
+++ b/compiler/front_end/attribute_checker.py
@@ -441,7 +441,7 @@
field_expression_type = type_check.unbounded_expression_type_for_physical_type(
field_type
)
- if field_expression_type.WhichOneof("type") not in (
+ if field_expression_type.which_type not in (
"integer",
"enumeration",
"boolean",
diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py
index 852c9ad..44b884d 100644
--- a/compiler/front_end/constraints.py
+++ b/compiler/front_end/constraints.py
@@ -51,7 +51,7 @@
def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, errors):
"""Checks that inner array dimensions are constant."""
- if type_ir.WhichOneof("size") == "automatic":
+ if type_ir.which_size == "automatic":
errors.append(
[
error.error(
@@ -61,7 +61,7 @@
)
]
)
- elif type_ir.WhichOneof("size") == "element_count":
+ elif type_ir.which_size == "element_count":
if not ir_util.is_constant(type_ir.element_count):
errors.append(
[
@@ -140,7 +140,7 @@
def _check_constancy_of_constant_references(expression, source_file_name, errors, ir):
"""Checks that constant_references are constant."""
- if expression.WhichOneof("expression") != "constant_reference":
+ if expression.which_expression != "constant_reference":
return
# This is a bit of a hack: really, we want to know that the referred-to object
# has no dependencies on any instance variables of its parent structure; i.e.,
@@ -326,7 +326,7 @@
physical_type = runtime_parameter.physical_type_alias
logical_type = runtime_parameter.type
size = ir_util.constant_value(physical_type.size_in_bits)
- if logical_type.WhichOneof("type") == "integer":
+ if logical_type.which_type == "integer":
integer_errors = _integer_bounds_errors(
logical_type.integer,
"parameter",
@@ -345,7 +345,7 @@
source_file_name,
)
)
- elif logical_type.WhichOneof("type") == "enumeration":
+ elif logical_type.which_type == "enumeration":
if physical_type.HasField("size_in_bits"):
# This seems a little weird: for `UInt`, `Int`, etc., the explicit size is
# required, but for enums it is banned. This is because enums have a
@@ -567,9 +567,9 @@
def _integer_bounds_errors_for_expression(expression, source_file_name):
"""Checks that `expression` is in range for int64_t or uint64_t."""
# Only check non-constant subexpressions.
- if expression.WhichOneof(
- "expression"
- ) == "function" and not ir_util.is_constant_type(expression.type):
+ if expression.which_expression == "function" and not ir_util.is_constant_type(
+ expression.type
+ ):
errors = []
for arg in expression.function.args:
errors += _integer_bounds_errors_for_expression(arg, source_file_name)
@@ -577,7 +577,7 @@
# Don't cascade bounds errors: report them at the lowest level they
# appear.
return errors
- if expression.type.WhichOneof("type") == "integer":
+ if expression.type.which_type == "integer":
errors = _integer_bounds_errors(
expression.type.integer,
"expression",
@@ -586,13 +586,13 @@
)
if errors:
return errors
- if expression.WhichOneof(
- "expression"
- ) == "function" and not ir_util.is_constant_type(expression.type):
+ if expression.which_expression == "function" and not ir_util.is_constant_type(
+ expression.type
+ ):
int64_only_clauses = []
uint64_only_clauses = []
for clause in [expression] + list(expression.function.args):
- if clause.type.WhichOneof("type") == "integer":
+ if clause.type.which_type == "integer":
arg_minimum = int(clause.type.integer.minimum_value)
arg_maximum = int(clause.type.integer.maximum_value)
if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum):
diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py
index cca36ee..51efca5 100644
--- a/compiler/front_end/expression_bounds.py
+++ b/compiler/front_end/expression_bounds.py
@@ -36,7 +36,7 @@
"""Adds appropriate bounding constraints to the given expression."""
if ir_util.is_constant_type(expression.type):
return
- expression_variety = expression.WhichOneof("expression")
+ expression_variety = expression.which_expression
if expression_variety == "constant":
_compute_constant_value_of_constant(expression)
elif expression_variety == "constant_reference":
@@ -51,7 +51,7 @@
_compute_constant_value_of_boolean_constant(expression)
else:
assert False, "Unknown expression variety {!r}".format(expression_variety)
- if expression.type.WhichOneof("type") == "integer":
+ if expression.type.which_type == "integer":
_assert_integer_constraints(expression)
@@ -138,7 +138,7 @@
ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
return
# Non-virtual non-integer fields do not (yet) have constraints.
- if expression.type.WhichOneof("type") == "integer":
+ if expression.type.which_type == "integer":
# TODO(bolms): These lines will need to change when support is added for
# fixed-point types.
expression.type.integer.modulus = "1"
@@ -205,7 +205,7 @@
def _compute_constraints_of_parameter(parameter):
- if parameter.type.WhichOneof("type") == "integer":
+ if parameter.type.which_type == "integer":
type_size = ir_util.constant_value(parameter.physical_type_alias.size_in_bits)
_set_integer_constraints_from_physical_type(
parameter, parameter.physical_type_alias, type_size
@@ -238,14 +238,14 @@
# [requires] attribute, are elevated to write-through fields, so that the
# [requires] clause can be checked in Write, CouldWriteValue, TryToWrite,
# Read, and Ok.
- if expression.type.WhichOneof("type") == "integer":
+ if expression.type.which_type == "integer":
assert expression.type.integer.modulus
assert expression.type.integer.modular_value
assert expression.type.integer.minimum_value
assert expression.type.integer.maximum_value
- elif expression.type.WhichOneof("type") == "enumeration":
+ elif expression.type.which_type == "enumeration":
assert expression.type.enumeration.name
- elif expression.type.WhichOneof("type") == "boolean":
+ elif expression.type.which_type == "boolean":
pass
else:
assert False, "Unexpected type for $logical_value"
@@ -559,9 +559,9 @@
def _compute_constraints_of_maximum_function(expression):
"""Computes the constraints of the $max function."""
- assert expression.type.WhichOneof("type") == "integer"
+ assert expression.type.which_type == "integer"
args = expression.function.args
- assert args[0].type.WhichOneof("type") == "integer"
+ assert args[0].type.which_type == "integer"
# The minimum value of the result occurs when every argument takes its minimum
# value, which means that the minimum result is the maximum-of-minimums.
expression.type.integer.minimum_value = str(
@@ -683,7 +683,7 @@
# constraints.check_constraints() will complain if minimum and maximum are not
# set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its
# weight, but for completeness I've left it in.
- if if_true.type.WhichOneof("type") == "integer":
+ if if_true.type.which_type == "integer":
# The minimum value of the choice is the minimum value of either side, and
# the maximum is the maximum value of either side.
expression.type.integer.minimum_value = str(
@@ -709,10 +709,10 @@
expression.type.integer.modulus = str(new_modulus)
expression.type.integer.modular_value = str(new_modular_value)
else:
- assert if_true.type.WhichOneof("type") in (
+ assert if_true.type.which_type in (
"boolean",
"enumeration",
- ), "Unknown type {} for expression".format(if_true.type.WhichOneof("type"))
+ ), "Unknown type {} for expression".format(if_true.type.which_type)
def _greatest_common_divisor(a, b):
diff --git a/compiler/front_end/expression_bounds_test.py b/compiler/front_end/expression_bounds_test.py
index 7af6836..e5bc25d 100644
--- a/compiler/front_end/expression_bounds_test.py
+++ b/compiler/front_end/expression_bounds_test.py
@@ -1007,7 +1007,7 @@
)
self.assertEqual([], expression_bounds.compute_constants(ir))
expr = ir.module[0].type[0].structure.field[1].existence_condition
- self.assertEqual("boolean", expr.type.WhichOneof("type"))
+ self.assertEqual("boolean", expr.type.which_type)
self.assertFalse(expr.type.boolean.HasField("value"))
def test_uint_value_range_for_explicit_size(self):
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 498b1a9..0590b0d 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -459,10 +459,7 @@
previous_reference = field_reference.path[0]
for ref in field_reference.path[1:]:
while ir_util.field_is_virtual(previous_field):
- if (
- previous_field.read_transform.WhichOneof("expression")
- == "field_reference"
- ):
+ if previous_field.read_transform.which_expression == "field_reference":
# Pass a separate error list into the recursive _resolve_field_reference
# call so that only one copy of the error for a particular reference
# will actually surface: in particular, the one that results from a
@@ -494,7 +491,7 @@
)
)
return
- if previous_field.type.WhichOneof("type") == "array_type":
+ if previous_field.type.which_type == "array_type":
errors.append(
array_subfield_error(
source_file_name,
@@ -503,7 +500,7 @@
)
)
return
- assert previous_field.type.WhichOneof("type") == "atomic_type"
+ assert previous_field.type.which_type == "atomic_type"
member_name = ir_data_utils.copy(
previous_field.type.atomic_type.reference.canonical_name
)
diff --git a/compiler/front_end/symbol_resolver_test.py b/compiler/front_end/symbol_resolver_test.py
index 693d157..ddf7783 100644
--- a/compiler/front_end/symbol_resolver_test.py
+++ b/compiler/front_end/symbol_resolver_test.py
@@ -179,7 +179,7 @@
struct_ir = ir.module[0].type[4].structure
array_type = struct_ir.field[0].type.array_type
# The symbol resolver should ignore void fields.
- self.assertEqual("automatic", array_type.WhichOneof("size"))
+ self.assertEqual("automatic", array_type.which_size)
def test_name_definitions_have_correct_canonical_names(self):
ir = self._construct_ir(_HAPPY_EMB)
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index c3f7870..ce4155d 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -24,10 +24,10 @@
def _type_check_expression(expression, source_file_name, ir, errors):
"""Checks and annotates the type of an expression and all subexpressions."""
- if ir_data_utils.reader(expression).type.WhichOneof("type"):
+ if ir_data_utils.reader(expression).type.which_type:
# This expression has already been type checked.
return
- expression_variety = expression.WhichOneof("expression")
+ expression_variety = expression.which_expression
if expression_variety == "constant":
_type_check_integer_constant(expression)
elif expression_variety == "constant_reference":
@@ -55,7 +55,7 @@
def _type_check(
expression, source_file_name, errors, type_oneof, type_name, expression_name
):
- if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof:
+ if ir_data_utils.reader(expression).type.which_type != type_oneof:
errors.append(
[
error.error(
@@ -80,7 +80,7 @@
def _kind_check_field_reference(expression, source_file_name, errors, expression_name):
- if expression.WhichOneof("expression") != "field_reference":
+ if expression.which_expression != "field_reference":
errors.append(
[
error.error(
@@ -335,7 +335,7 @@
def _annotate_parameter_type(parameter, ir, source_file_name, errors):
- if parameter.physical_type_alias.WhichOneof("type") != "atomic_type":
+ if parameter.physical_type_alias.which_type != "atomic_type":
errors.append(
[
error.error(
@@ -353,13 +353,13 @@
def _types_are_compatible(a, b):
"""Returns true if a and b have compatible types."""
- if a.type.WhichOneof("type") != b.type.WhichOneof("type"):
+ if a.type.which_type != b.type.which_type:
return False
- elif a.type.WhichOneof("type") == "enumeration":
+ elif a.type.which_type == "enumeration":
return ir_util.hashable_form_of_reference(
a.type.enumeration.name
) == ir_util.hashable_form_of_reference(b.type.enumeration.name)
- elif a.type.WhichOneof("type") in ("integer", "boolean"):
+ elif a.type.which_type in ("integer", "boolean"):
# All integers are compatible with integers; booleans are compatible with
# booleans
return True
@@ -383,7 +383,7 @@
left = expression.function.args[0]
right = expression.function.args[1]
for argument, name in ((left, "Left"), (right, "Right")):
- if argument.type.WhichOneof("type") not in acceptable_types:
+ if argument.type.which_type not in acceptable_types:
errors.append(
[
error.error(
@@ -415,7 +415,7 @@
def _type_check_choice_operator(expression, source_file_name, errors):
"""Checks the type of the choice operator cond ? if_true : if_false."""
condition = expression.function.args[0]
- if condition.type.WhichOneof("type") != "boolean":
+ if condition.type.which_type != "boolean":
errors.append(
[
error.error(
@@ -426,7 +426,7 @@
]
)
if_true = expression.function.args[1]
- if if_true.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"):
+ if if_true.type.which_type not in ("integer", "boolean", "enumeration"):
errors.append(
[
error.error(
@@ -450,11 +450,11 @@
)
]
)
- if if_true.type.WhichOneof("type") == "integer":
+ if if_true.type.which_type == "integer":
_annotate_as_integer(expression)
- elif if_true.type.WhichOneof("type") == "boolean":
+ elif if_true.type.which_type == "boolean":
_annotate_as_boolean(expression)
- elif if_true.type.WhichOneof("type") == "enumeration":
+ elif if_true.type.which_type == "enumeration":
ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(
if_true.type.enumeration.name
)
@@ -492,9 +492,9 @@
def _type_name_for_error_messages(expression_type):
- if expression_type.WhichOneof("type") == "integer":
+ if expression_type.which_type == "integer":
return "integer"
- elif expression_type.WhichOneof("type") == "enumeration":
+ elif expression_type.which_type == "enumeration":
# TODO(bolms): Should this be the fully-qualified name?
return expression_type.enumeration.name.canonical_name.object_path[-1]
assert False, "Shouldn't be here."
@@ -526,7 +526,7 @@
)
return
for i in range(len(referenced_type.runtime_parameter)):
- if referenced_type.runtime_parameter[i].type.WhichOneof("type") not in (
+ if referenced_type.runtime_parameter[i].type.which_type not in (
"integer",
"boolean",
"enumeration",
@@ -535,9 +535,10 @@
# definition site; no need for another, probably-confusing error at any
# usage sites.
continue
- if atomic_type.runtime_parameter[i].type.WhichOneof(
- "type"
- ) != referenced_type.runtime_parameter[i].type.WhichOneof("type"):
+ if (
+ atomic_type.runtime_parameter[i].type.which_type
+ != referenced_type.runtime_parameter[i].type.which_type
+ ):
errors.append(
[
error.error(
@@ -565,7 +566,7 @@
def _type_check_parameter(runtime_parameter, source_file_name, errors):
"""Checks the type of a parameter to a physical type."""
- if runtime_parameter.type.WhichOneof("type") not in ("integer", "enumeration"):
+ if runtime_parameter.type.which_type not in ("integer", "enumeration"):
errors.append(
[
error.error(
diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py
index 995f20d..766f9e8 100644
--- a/compiler/front_end/type_check_test.py
+++ b/compiler/front_end/type_check_test.py
@@ -39,7 +39,7 @@
)
self.assertEqual([], type_check.annotate_types(ir))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "integer")
+ self.assertEqual(expression.type.which_type, "integer")
def test_adds_boolean_constant_type(self):
ir = self._make_ir(
@@ -51,7 +51,7 @@
ir_data_utils.IrDataSerializer(ir).to_json(indent=2),
)
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+ self.assertEqual(expression.type.which_type, "boolean")
def test_adds_enum_constant_type(self):
ir = self._make_ir(
@@ -62,7 +62,7 @@
)
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
expression = ir.module[0].type[0].structure.field[0].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
+ self.assertEqual(expression.type.which_type, "enumeration")
enum_type_name = expression.type.enumeration.name.canonical_name
self.assertEqual(enum_type_name.module_file, "m.emb")
self.assertEqual(enum_type_name.object_path[0], "Enum")
@@ -77,7 +77,7 @@
)
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
+ self.assertEqual(expression.type.which_type, "enumeration")
enum_type_name = expression.type.enumeration.name.canonical_name
self.assertEqual(enum_type_name.module_file, "m.emb")
self.assertEqual(enum_type_name.object_path[0], "Enum")
@@ -88,9 +88,9 @@
)
self.assertEqual([], type_check.annotate_types(ir))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "integer")
- self.assertEqual(expression.function.args[0].type.WhichOneof("type"), "integer")
- self.assertEqual(expression.function.args[1].type.WhichOneof("type"), "integer")
+ self.assertEqual(expression.type.which_type, "integer")
+ self.assertEqual(expression.function.args[0].type.which_type, "integer")
+ self.assertEqual(expression.function.args[1].type.which_type, "integer")
def test_adds_enum_operation_type(self):
ir = self._make_ir(
@@ -102,13 +102,9 @@
)
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "boolean")
- self.assertEqual(
- expression.function.args[0].type.WhichOneof("type"), "enumeration"
- )
- self.assertEqual(
- expression.function.args[1].type.WhichOneof("type"), "enumeration"
- )
+ self.assertEqual(expression.type.which_type, "boolean")
+ self.assertEqual(expression.function.args[0].type.which_type, "enumeration")
+ self.assertEqual(expression.function.args[1].type.which_type, "enumeration")
def test_adds_enum_comparison_operation_type(self):
ir = self._make_ir(
@@ -120,13 +116,9 @@
)
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "boolean")
- self.assertEqual(
- expression.function.args[0].type.WhichOneof("type"), "enumeration"
- )
- self.assertEqual(
- expression.function.args[1].type.WhichOneof("type"), "enumeration"
- )
+ self.assertEqual(expression.type.which_type, "boolean")
+ self.assertEqual(expression.function.args[0].type.which_type, "enumeration")
+ self.assertEqual(expression.function.args[1].type.which_type, "enumeration")
def test_adds_integer_field_type(self):
ir = self._make_ir(
@@ -134,7 +126,7 @@
)
self.assertEqual([], type_check.annotate_types(ir))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "integer")
+ self.assertEqual(expression.type.which_type, "integer")
def test_adds_opaque_field_type(self):
ir = self._make_ir(
@@ -146,7 +138,7 @@
)
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "opaque")
+ self.assertEqual(expression.type.which_type, "opaque")
def test_adds_opaque_field_type_for_array(self):
ir = self._make_ir(
@@ -154,7 +146,7 @@
)
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "opaque")
+ self.assertEqual(expression.type.which_type, "opaque")
def test_error_on_bad_plus_operand_types(self):
ir = self._make_ir(
@@ -395,7 +387,7 @@
)
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- self.assertEqual("boolean", expression.type.WhichOneof("type"))
+ self.assertEqual("boolean", expression.type.which_type)
def test_choice_of_integers(self):
ir = self._make_ir(
@@ -405,7 +397,7 @@
)
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ self.assertEqual("integer", expression.type.which_type)
def test_choice_of_enums(self):
ir = self._make_ir(
@@ -417,7 +409,7 @@
)
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- self.assertEqual("enumeration", expression.type.WhichOneof("type"))
+ self.assertEqual("enumeration", expression.type.which_type)
self.assertFalse(expression.type.enumeration.HasField("value"))
self.assertEqual(
"m.emb", expression.type.enumeration.name.canonical_name.module_file
@@ -579,7 +571,7 @@
ir = self._make_ir("struct Foo:\n" " $max(1, 2, 3) [+1] UInt:8[] x\n")
expression = ir.module[0].type[0].structure.field[0].location.start
self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ self.assertEqual("integer", expression.type.which_type)
def test_error_on_bad_max_argument(self):
ir = self._make_ir(
@@ -622,7 +614,7 @@
ir = self._make_ir("struct Foo:\n" " $upper_bound(3) [+1] UInt:8[] x\n")
expression = ir.module[0].type[0].structure.field[0].location.start
self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ self.assertEqual("integer", expression.type.which_type)
def test_upper_bound_too_few_arguments(self):
ir = self._make_ir("struct Foo:\n" " $upper_bound() [+1] UInt:8[] x\n")
@@ -681,7 +673,7 @@
ir = self._make_ir("struct Foo:\n" " $lower_bound(3) [+1] UInt:8[] x\n")
expression = ir.module[0].type[0].structure.field[0].location.start
self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ self.assertEqual("integer", expression.type.which_type)
def test_lower_bound_too_few_arguments(self):
ir = self._make_ir("struct Foo:\n" " $lower_bound() [+1] UInt:8[] x\n")
diff --git a/compiler/front_end/write_inference.py b/compiler/front_end/write_inference.py
index 8353306..bd4e2ba 100644
--- a/compiler/front_end/write_inference.py
+++ b/compiler/front_end/write_inference.py
@@ -51,9 +51,9 @@
def _recursively_find_field_reference_path(expression):
"""Recursive implementation of _find_field_reference_path."""
- if expression.WhichOneof("expression") == "field_reference":
+ if expression.which_expression == "field_reference":
return 1, []
- elif expression.WhichOneof("expression") == "function":
+ elif expression.which_expression == "function":
field_count = 0
path = []
for index in range(len(expression.function.args)):
@@ -228,7 +228,7 @@
# requirement.
requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
if (
- field_checker.read_transform.WhichOneof("expression") != "field_reference"
+ field_checker.read_transform.which_expression != "field_reference"
or requires_attr is not None
):
inverse = _invert_expression(field.read_transform, ir)
diff --git a/compiler/front_end/write_inference_test.py b/compiler/front_end/write_inference_test.py
index c6afa2f..f4ea8bf 100644
--- a/compiler/front_end/write_inference_test.py
+++ b/compiler/front_end/write_inference_test.py
@@ -185,7 +185,7 @@
ir = self._make_ir("struct Foo(x: UInt:8):\n" " let y = 50 + x\n")
self.assertEqual([], write_inference.set_write_methods(ir))
field = ir.module[0].type[0].structure.field[0]
- self.assertEqual("read_only", field.write_method.WhichOneof("method"))
+ self.assertEqual("read_only", field.write_method.which_method)
def test_adds_transform_write_method_with_complex_auxiliary_subexpression(self):
ir = self._make_ir(
diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py
index 2ebbffa..d20f191 100644
--- a/compiler/util/attribute_util.py
+++ b/compiler/util/attribute_util.py
@@ -57,7 +57,7 @@
def _is_boolean(attr, module_source_file):
"""Checks if the given attr is a boolean."""
- if attr.value.expression.type.WhichOneof("type") != "boolean":
+ if attr.value.expression.type.which_type != "boolean":
return [
[
error.error(
@@ -76,7 +76,7 @@
"""Checks if the given attr is an integer constant expression."""
if (
not attr.value.HasField("expression")
- or attr.value.expression.type.WhichOneof("type") != "integer"
+ or attr.value.expression.type.which_type != "integer"
):
return [
[
diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py
index fda124e..cd12b96 100644
--- a/compiler/util/ir_data.py
+++ b/compiler/util/ir_data.py
@@ -106,20 +106,6 @@
"""Indicates if this class has the given field defined and it is set."""
return getattr(self, name, None) is not None
- # Non-PEP8 name to mimic the Google Protobuf interface.
- def WhichOneof(self, oneof_name): # pylint:disable=invalid-name
- """Indicates which field has been set for the oneof value.
-
- Args:
- oneof_name: the name of the oneof construct to test.
-
- Returns: the field name, or None if no field has been set.
- """
- for field_name, oneof in self.field_specs.oneof_mappings:
- if oneof == oneof_name and self.HasField(field_name):
- return field_name
- return None
-
################################################################################
# From here to the end of the file are actual structure definitions.
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index 76df36c..e6b6ddb 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -175,11 +175,7 @@
self.all_field_specs = specs
self.field_specs = tuple(specs.values())
self.dataclass_field_specs = {k: v for k, v in specs.items() if v.is_dataclass}
- self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof}
self.sequence_field_specs = tuple(v for v in specs.values() if v.is_sequence)
- self.oneof_mappings = tuple(
- (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof
- )
def all_ir_classes(mod):
@@ -399,18 +395,23 @@
super().__init__()
self.oneof = oneof
self.owner_type = None
- self.proxy_name: str = ""
+ self.proxy_name: str = f"_value_{oneof}"
+ self.proxy_choice_name: str = f"which_{oneof}"
self.name: str = ""
def __set_name__(self, owner, name):
self.name = name
- self.proxy_name = f"_{name}"
self.owner_type = owner
- # Add our empty proxy field to the class.
+ # Add the empty proxy fields to the class. This may re-initialize
+ # these if another field in this oneof got there first.
setattr(owner, self.proxy_name, None)
+ setattr(owner, self.proxy_choice_name, None)
def __get__(self, obj, objtype=None):
- return getattr(obj, self.proxy_name)
+ if getattr(obj, self.proxy_choice_name, None) == self.name:
+ return getattr(obj, self.proxy_name)
+ else:
+ return None
def __set__(self, obj, value):
if value is self:
@@ -418,15 +419,13 @@
# default to None.
value = None
- if value is not None:
- # Clear the others
- for name, oneof in IrDataclassSpecs.get_specs(
- self.owner_type
- ).oneof_mappings:
- if oneof == self.oneof and name != self.name:
- setattr(obj, name, None)
-
- setattr(obj, self.proxy_name, value)
+ if value is None:
+ if getattr(obj, self.proxy_choice_name) == self.name:
+ setattr(obj, self.proxy_name, None)
+ setattr(obj, self.proxy_choice_name, None)
+ else:
+ setattr(obj, self.proxy_name, value)
+ setattr(obj, self.proxy_choice_name, self.name)
def oneof_field(name: str):
diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py
index f6ee2ca..0376b69 100644
--- a/compiler/util/ir_data_fields_test.py
+++ b/compiler/util/ir_data_fields_test.py
@@ -162,29 +162,6 @@
self.assertEqual(one_of_field_test.int_field_2, 200)
self.assertFalse(one_of_field_test.normal_field)
- def test_oneof_specs(self):
- """Tests the `oneof_field_specs` filter."""
- expected = {
- "int_field_1": ir_data_fields.make_field_spec(
- "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
- ),
- "int_field_2": ir_data_fields.make_field_spec(
- "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
- ),
- }
- actual = ir_data_fields.IrDataclassSpecs.get_specs(
- OneofFieldTest
- ).oneof_field_specs
- self.assertDictEqual(actual, expected)
-
- def test_oneof_mappings(self):
- """Tests the `oneof_mappings` function."""
- expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
- actual = ir_data_fields.IrDataclassSpecs.get_specs(
- OneofFieldTest
- ).oneof_mappings
- self.assertTupleEqual(actual, expected)
-
class IrDataFieldsTest(unittest.TestCase):
"""Tests misc methods in ir_data_fields."""
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index 99fbf7d..e7679a6 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -277,7 +277,7 @@
if ir is None:
return object.__getattribute__(self, name)
- if name in ("HasField", "WhichOneof"):
+ if name in ("HasField",):
return getattr(ir, name)
field_spec = field_specs(ir).get(name)
@@ -362,8 +362,13 @@
if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
if name == "HasField":
return lambda x: False
- if name == "WhichOneof":
- return lambda x: None
+ # This *should* be limited to only the `which_` attributes that
+ # correspond to real oneofs, but that would add complexity and
+ # runtime, and the odds are low that laxness here causes a bug
+ # -- the same code needs to run against real IR objects that
+ # will raise if a nonexistent `which_` field is accessed.
+ if name.startswith("which_"):
+ return None
return object.__getattribute__(ir_or_spec, name)
if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
diff --git a/compiler/util/ir_util.py b/compiler/util/ir_util.py
index 2d8763b..c1de244 100644
--- a/compiler/util/ir_util.py
+++ b/compiler/util/ir_util.py
@@ -72,7 +72,7 @@
attribute_value = get_attribute(attribute_list, name)
if (
not attribute_value
- or attribute_value.expression.type.WhichOneof("type") != "integer"
+ or attribute_value.expression.type.which_type != "integer"
or not is_constant(attribute_value.expression)
):
return default_value
@@ -98,42 +98,42 @@
if expression is None:
return None
expression = ir_data_utils.reader(expression)
- if expression.WhichOneof("expression") == "constant":
+ if expression.which_expression == "constant":
return int(expression.constant.value or 0)
- elif expression.WhichOneof("expression") == "constant_reference":
+ elif expression.which_expression == "constant_reference":
# We can't look up the constant reference without the IR, but by the time
# constant_value is called, the actual values should have been propagated to
# the type information.
- if expression.type.WhichOneof("type") == "integer":
+ if expression.type.which_type == "integer":
assert expression.type.integer.modulus == "infinity"
return int(expression.type.integer.modular_value)
- elif expression.type.WhichOneof("type") == "boolean":
+ elif expression.type.which_type == "boolean":
assert expression.type.boolean.HasField("value")
return expression.type.boolean.value
- elif expression.type.WhichOneof("type") == "enumeration":
+ elif expression.type.which_type == "enumeration":
assert expression.type.enumeration.HasField("value")
return int(expression.type.enumeration.value)
else:
assert False, "Unexpected expression type {}".format(
- expression.type.WhichOneof("type")
+ expression.type.which_type
)
- elif expression.WhichOneof("expression") == "function":
+ elif expression.which_expression == "function":
return _constant_value_of_function(expression.function, bindings)
- elif expression.WhichOneof("expression") == "field_reference":
+ elif expression.which_expression == "field_reference":
return None
- elif expression.WhichOneof("expression") == "boolean_constant":
+ elif expression.which_expression == "boolean_constant":
return expression.boolean_constant.value
- elif expression.WhichOneof("expression") == "builtin_reference":
+ elif expression.which_expression == "builtin_reference":
name = expression.builtin_reference.canonical_name.object_path[0]
if bindings and name in bindings:
return bindings[name]
else:
return None
- elif expression.WhichOneof("expression") is None:
+ elif expression.which_expression is None:
return None
else:
assert False, "Unexpected expression kind {}".format(
- expression.WhichOneof("expression")
+ expression.which_expression
)
@@ -366,11 +366,11 @@
"""
array_multiplier = 1
while type_ir.HasField("array_type"):
- if type_ir.array_type.WhichOneof("size") == "automatic":
+ if type_ir.array_type.which_size == "automatic":
return None
else:
assert (
- type_ir.array_type.WhichOneof("size") == "element_count"
+ type_ir.array_type.which_size == "element_count"
), 'Expected array size to be "automatic" or "element_count".'
element_count = type_ir.array_type.element_count
if not is_constant(element_count):