Merge pull request #71 from reventlov/more_enum_comparisons
Allow all comparison operators for enums.
diff --git a/compiler/back_end/cpp/testcode/condition_test.cc b/compiler/back_end/cpp/testcode/condition_test.cc
index 588a6f7..1645f48 100644
--- a/compiler/back_end/cpp/testcode/condition_test.cc
+++ b/compiler/back_end/cpp/testcode/condition_test.cc
@@ -439,7 +439,10 @@
EXPECT_TRUE(writer.Ok());
ASSERT_TRUE(writer.SizeIsKnown());
EXPECT_EQ(2U, writer.SizeInBytes());
+ EXPECT_TRUE(writer.has_xc().Value());
EXPECT_EQ(0, writer.xc().Read());
+ EXPECT_TRUE(writer.has_xc2().Value());
+ EXPECT_EQ(0, writer.xc2().Read());
}
TEST(Conditional, FalseEnumBasedCondition) {
@@ -449,6 +452,9 @@
ASSERT_TRUE(writer.SizeIsKnown());
EXPECT_EQ(1U, writer.SizeInBytes());
EXPECT_FALSE(writer.xc().Ok());
+ EXPECT_FALSE(writer.has_xc().Value());
+ EXPECT_FALSE(writer.xc2().Ok());
+ EXPECT_FALSE(writer.has_xc2().Value());
}
TEST(Conditional, TrueEnumBasedNegativeCondition) {
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index 5369ca6..edc9d93 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -111,8 +111,10 @@
for arg in expression.function.args:
_type_check_expression(arg, source_file_name, ir, errors)
function = expression.function.function
- if function in (ir_pb2.Function.EQUALITY, ir_pb2.Function.INEQUALITY):
- _type_check_equality_operator(expression, source_file_name, errors)
+ if function in (ir_pb2.Function.EQUALITY, ir_pb2.Function.INEQUALITY,
+ ir_pb2.Function.LESS, ir_pb2.Function.LESS_OR_EQUAL,
+ ir_pb2.Function.GREATER, ir_pb2.Function.GREATER_OR_EQUAL):
+ _type_check_comparison_operator(expression, source_file_name, errors)
elif function == ir_pb2.Function.CHOICE:
_type_check_choice_operator(expression, source_file_name, errors)
else:
@@ -138,13 +140,6 @@
"operator"),
ir_pb2.Function.AND: (bool_result, bool_args, binary, 2, 2, "operator"),
ir_pb2.Function.OR: (bool_result, bool_args, binary, 2, 2, "operator"),
- ir_pb2.Function.LESS: (bool_result, int_args, binary, 2, 2, "operator"),
- ir_pb2.Function.LESS_OR_EQUAL: (bool_result, int_args, binary, 2, 2,
- "operator"),
- ir_pb2.Function.GREATER: (bool_result, int_args, binary, 2, 2,
- "operator"),
- ir_pb2.Function.GREATER_OR_EQUAL: (bool_result, int_args, binary, 2, 2,
- "operator"),
ir_pb2.Function.MAXIMUM: (int_result, int_args, n_ary, 1, None,
"function"),
ir_pb2.Function.PRESENCE: (bool_result, field_args, n_ary, 1, 1,
@@ -268,18 +263,28 @@
assert False, "_types_are_compatible works with enums, integers, booleans."
-def _type_check_equality_operator(expression, source_file_name, errors):
- """Checks the type of an equality operator (== or !=)."""
+def _type_check_comparison_operator(expression, source_file_name, errors):
+ """Checks the type of a comparison operator (==, !=, <, >, >=, <=)."""
+ # Applying less than or greater than to a boolean is likely a mistake, so
+ # only equality and inequality are allowed for booleans.
+ if expression.function.function in (ir_pb2.Function.EQUALITY,
+ ir_pb2.Function.INEQUALITY):
+ acceptable_types = ("integer", "boolean", "enumeration")
+ acceptable_types_for_humans = "an integer, boolean, or enum"
+ else:
+ acceptable_types = ("integer", "enumeration")
+ acceptable_types_for_humans = "an integer or enum"
left = expression.function.args[0]
- if left.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"):
- errors.append([
- error.error(source_file_name, left.source_location,
- "Left argument of operator '{}' must be an integer, "
- "boolean, or enum.".format(
- expression.function.function_name.text))
- ])
- return
right = expression.function.args[1]
+ for (argument, name) in ((left, "Left"), (right, "Right")):
+ if argument.type.WhichOneof("type") not in acceptable_types:
+ errors.append([
+ error.error(source_file_name, argument.source_location,
+ "{} argument of operator '{}' must be {}.".format(
+ name, expression.function.function_name.text,
+ acceptable_types_for_humans))
+ ])
+ return
if not _types_are_compatible(left, right):
errors.append([
error.error(source_file_name, expression.source_location,
diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py
index 6faaa25..c5f174f 100644
--- a/compiler/front_end/type_check_test.py
+++ b/compiler/front_end/type_check_test.py
@@ -99,6 +99,20 @@
self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
"enumeration")
+ def test_adds_enum_comparison_operation_type(self):
+ ir = self._make_ir("struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+Enum.VAL>=Enum.VAL] UInt:8[] y\n"
+ "enum Enum:\n"
+ " VAL = 1\n")
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+ self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
+ "enumeration")
+ self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
+ "enumeration")
+
def test_adds_integer_field_type(self):
ir = self._make_ir("struct Foo:\n"
" 0 [+1] UInt x\n"
@@ -172,9 +186,9 @@
" 1 [+1==x] UInt:8[] y\n")
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Both arguments of operator '==' must have the same "
- "type.")]
+ [error.error("m.emb", expression.function.args[1].source_location,
+ "Right argument of operator '==' must be an integer, "
+ "boolean, or enum.")]
], error.filter_errors(type_check.annotate_types(ir)))
def test_error_on_equality_mismatched_operands_int_bool(self):
@@ -188,14 +202,27 @@
"type.")]
], error.filter_errors(type_check.annotate_types(ir)))
- def test_error_on_equality_mismatched_operands_bool_int(self):
+ def test_error_on_mismatched_comparison_operands(self):
ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+true==1] UInt:8[] y\n")
+ " 0 [+1] UInt:8 x\n"
+ " 1 [+x>=Bar.BAR] UInt:8[] y\n"
+ "enum Bar:\n"
+ " BAR = 1\n")
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([
[error.error("m.emb", expression.source_location,
- "Both arguments of operator '==' must have the same "
+ "Both arguments of operator '>=' must have the same "
+ "type.")]
+ ], error.filter_errors(type_check.annotate_types(ir)))
+
+ def test_error_on_equality_mismatched_operands_bool_int(self):
+ ir = self._make_ir("struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+true!=1] UInt:8[] y\n")
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual([
+ [error.error("m.emb", expression.source_location,
+ "Both arguments of operator '!=' must have the same "
"type.")]
], error.filter_errors(type_check.annotate_types(ir)))
@@ -322,18 +349,20 @@
" 1 [+true<1] UInt:8[] y\n")
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Left argument of operator '<' must be an integer.")]
+ [error.error(
+ "m.emb", expression.function.args[0].source_location,
+ "Left argument of operator '<' must be an integer or enum.")]
], error.filter_errors(type_check.annotate_types(ir)))
def test_error_on_bad_right_comparison_operand_type(self):
ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
+ " 0 [+1] UInt x\n"
" 1 [+1>=true] UInt:8[] y\n")
expression = ir.module[0].type[0].structure.field[1].location.size
self.assertEqual([
- [error.error("m.emb", expression.function.args[1].source_location,
- "Right argument of operator '>=' must be an integer.")]
+ [error.error(
+ "m.emb", expression.function.args[1].source_location,
+ "Right argument of operator '>=' must be an integer or enum.")]
], error.filter_errors(type_check.annotate_types(ir)))
def test_error_on_bad_boolean_operand_type(self):
diff --git a/testdata/condition.emb b/testdata/condition.emb
index acfa304..4d9f2a5 100644
--- a/testdata/condition.emb
+++ b/testdata/condition.emb
@@ -112,6 +112,8 @@
0 [+1] OnOff x
if x == OnOff.ON:
1 [+1] UInt xc
+ if x > OnOff.OFF:
+ 1 [+1] UInt xc2
struct NegativeEnumCondition: