Allow all comparison operators for enums.
diff --git a/compiler/back_end/cpp/testcode/condition_test.cc b/compiler/back_end/cpp/testcode/condition_test.cc index 588a6f7..1645f48 100644 --- a/compiler/back_end/cpp/testcode/condition_test.cc +++ b/compiler/back_end/cpp/testcode/condition_test.cc
@@ -439,7 +439,10 @@ EXPECT_TRUE(writer.Ok()); ASSERT_TRUE(writer.SizeIsKnown()); EXPECT_EQ(2U, writer.SizeInBytes()); + EXPECT_TRUE(writer.has_xc().Value()); EXPECT_EQ(0, writer.xc().Read()); + EXPECT_TRUE(writer.has_xc2().Value()); + EXPECT_EQ(0, writer.xc2().Read()); } TEST(Conditional, FalseEnumBasedCondition) { @@ -449,6 +452,9 @@ ASSERT_TRUE(writer.SizeIsKnown()); EXPECT_EQ(1U, writer.SizeInBytes()); EXPECT_FALSE(writer.xc().Ok()); + EXPECT_FALSE(writer.has_xc().Value()); + EXPECT_FALSE(writer.xc2().Ok()); + EXPECT_FALSE(writer.has_xc2().Value()); } TEST(Conditional, TrueEnumBasedNegativeCondition) {
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py index 5369ca6..edc9d93 100644 --- a/compiler/front_end/type_check.py +++ b/compiler/front_end/type_check.py
@@ -111,8 +111,10 @@ for arg in expression.function.args: _type_check_expression(arg, source_file_name, ir, errors) function = expression.function.function - if function in (ir_pb2.Function.EQUALITY, ir_pb2.Function.INEQUALITY): - _type_check_equality_operator(expression, source_file_name, errors) + if function in (ir_pb2.Function.EQUALITY, ir_pb2.Function.INEQUALITY, + ir_pb2.Function.LESS, ir_pb2.Function.LESS_OR_EQUAL, + ir_pb2.Function.GREATER, ir_pb2.Function.GREATER_OR_EQUAL): + _type_check_comparison_operator(expression, source_file_name, errors) elif function == ir_pb2.Function.CHOICE: _type_check_choice_operator(expression, source_file_name, errors) else: @@ -138,13 +140,6 @@ "operator"), ir_pb2.Function.AND: (bool_result, bool_args, binary, 2, 2, "operator"), ir_pb2.Function.OR: (bool_result, bool_args, binary, 2, 2, "operator"), - ir_pb2.Function.LESS: (bool_result, int_args, binary, 2, 2, "operator"), - ir_pb2.Function.LESS_OR_EQUAL: (bool_result, int_args, binary, 2, 2, - "operator"), - ir_pb2.Function.GREATER: (bool_result, int_args, binary, 2, 2, - "operator"), - ir_pb2.Function.GREATER_OR_EQUAL: (bool_result, int_args, binary, 2, 2, - "operator"), ir_pb2.Function.MAXIMUM: (int_result, int_args, n_ary, 1, None, "function"), ir_pb2.Function.PRESENCE: (bool_result, field_args, n_ary, 1, 1, @@ -268,18 +263,28 @@ assert False, "_types_are_compatible works with enums, integers, booleans." -def _type_check_equality_operator(expression, source_file_name, errors): - """Checks the type of an equality operator (== or !=).""" +def _type_check_comparison_operator(expression, source_file_name, errors): + """Checks the type of a comparison operator (==, !=, <, >, >=, <=).""" + # Applying less than or greater than to a boolean is likely a mistake, so + # only equality and inequality are allowed for booleans. + if expression.function.function in (ir_pb2.Function.EQUALITY, + ir_pb2.Function.INEQUALITY): + acceptable_types = ("integer", "boolean", "enumeration") + acceptable_types_for_humans = "an integer, boolean, or enum" + else: + acceptable_types = ("integer", "enumeration") + acceptable_types_for_humans = "an integer or enum" left = expression.function.args[0] - if left.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"): - errors.append([ - error.error(source_file_name, left.source_location, - "Left argument of operator '{}' must be an integer, " - "boolean, or enum.".format( - expression.function.function_name.text)) - ]) - return right = expression.function.args[1] + for (argument, name) in ((left, "Left"), (right, "Right")): + if argument.type.WhichOneof("type") not in acceptable_types: + errors.append([ + error.error(source_file_name, argument.source_location, + "{} argument of operator '{}' must be {}.".format( + name, expression.function.function_name.text, + acceptable_types_for_humans)) + ]) + return if not _types_are_compatible(left, right): errors.append([ error.error(source_file_name, expression.source_location,
diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py index 6faaa25..c5f174f 100644 --- a/compiler/front_end/type_check_test.py +++ b/compiler/front_end/type_check_test.py
@@ -99,6 +99,20 @@ self.assertEqual(expression.function.args[1].type.WhichOneof("type"), "enumeration") + def test_adds_enum_comparison_operation_type(self): + ir = self._make_ir("struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+Enum.VAL>=Enum.VAL] UInt:8[] y\n" + "enum Enum:\n" + " VAL = 1\n") + self.assertEqual([], error.filter_errors(type_check.annotate_types(ir))) + expression = ir.module[0].type[0].structure.field[1].location.size + self.assertEqual(expression.type.WhichOneof("type"), "boolean") + self.assertEqual(expression.function.args[0].type.WhichOneof("type"), + "enumeration") + self.assertEqual(expression.function.args[1].type.WhichOneof("type"), + "enumeration") + def test_adds_integer_field_type(self): ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" @@ -172,9 +186,9 @@ " 1 [+1==x] UInt:8[] y\n") expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([ - [error.error("m.emb", expression.source_location, - "Both arguments of operator '==' must have the same " - "type.")] + [error.error("m.emb", expression.function.args[1].source_location, + "Right argument of operator '==' must be an integer, " + "boolean, or enum.")] ], error.filter_errors(type_check.annotate_types(ir))) def test_error_on_equality_mismatched_operands_int_bool(self): @@ -188,14 +202,27 @@ "type.")] ], error.filter_errors(type_check.annotate_types(ir))) - def test_error_on_equality_mismatched_operands_bool_int(self): + def test_error_on_mismatched_comparison_operands(self): ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" - " 1 [+true==1] UInt:8[] y\n") + " 0 [+1] UInt:8 x\n" + " 1 [+x>=Bar.BAR] UInt:8[] y\n" + "enum Bar:\n" + " BAR = 1\n") expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([ [error.error("m.emb", expression.source_location, - "Both arguments of operator '==' must have the same " + "Both arguments of operator '>=' must have the same " + "type.")] + ], error.filter_errors(type_check.annotate_types(ir))) + + def test_error_on_equality_mismatched_operands_bool_int(self): + ir = self._make_ir("struct Foo:\n" + " 0 [+1] UInt x\n" + " 1 [+true!=1] UInt:8[] y\n") + expression = ir.module[0].type[0].structure.field[1].location.size + self.assertEqual([ + [error.error("m.emb", expression.source_location, + "Both arguments of operator '!=' must have the same " "type.")] ], error.filter_errors(type_check.annotate_types(ir))) @@ -322,18 +349,20 @@ " 1 [+true<1] UInt:8[] y\n") expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([ - [error.error("m.emb", expression.function.args[0].source_location, - "Left argument of operator '<' must be an integer.")] + [error.error( + "m.emb", expression.function.args[0].source_location, + "Left argument of operator '<' must be an integer or enum.")] ], error.filter_errors(type_check.annotate_types(ir))) def test_error_on_bad_right_comparison_operand_type(self): ir = self._make_ir("struct Foo:\n" - " 0 [+1] UInt x\n" + " 0 [+1] UInt x\n" " 1 [+1>=true] UInt:8[] y\n") expression = ir.module[0].type[0].structure.field[1].location.size self.assertEqual([ - [error.error("m.emb", expression.function.args[1].source_location, - "Right argument of operator '>=' must be an integer.")] + [error.error( + "m.emb", expression.function.args[1].source_location, + "Right argument of operator '>=' must be an integer or enum.")] ], error.filter_errors(type_check.annotate_types(ir))) def test_error_on_bad_boolean_operand_type(self):
diff --git a/testdata/condition.emb b/testdata/condition.emb index acfa304..4d9f2a5 100644 --- a/testdata/condition.emb +++ b/testdata/condition.emb
@@ -112,6 +112,8 @@ 0 [+1] OnOff x if x == OnOff.ON: 1 [+1] UInt xc + if x > OnOff.OFF: + 1 [+1] UInt xc2 struct NegativeEnumCondition: