blob: f562cc8aa8f4e7188556d8cde3e8efd5b886ecc3 [file] [log] [blame]
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functions for checking expression types."""
from compiler.front_end import attributes
from compiler.util import error
from compiler.util import ir_data
from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import traverse_ir
def _type_check_expression(expression, source_file_name, ir, errors):
    """Type checks `expression` (and its subexpressions), annotating its type.

    An expression whose `type` oneof is already set is assumed to have been
    checked before and is left alone, so calling this repeatedly is harmless.

    Arguments:
        expression: the ir_data.Expression to check and annotate.
        source_file_name: the file name attached to any reported errors.
        ir: the whole IR, used to resolve references.
        errors: a list to which lists of errors are appended.
    """
    if ir_data_utils.reader(expression).type.WhichOneof("type"):
        return
    variety = expression.WhichOneof("expression")
    checkers = {
        "constant": lambda: _type_check_integer_constant(expression),
        "constant_reference": lambda: _type_check_constant_reference(
            expression, source_file_name, ir, errors
        ),
        "function": lambda: _type_check_operation(
            expression, source_file_name, ir, errors
        ),
        "field_reference": lambda: _type_check_local_reference(
            expression, ir, errors
        ),
        "boolean_constant": lambda: _type_check_boolean_constant(expression),
        "builtin_reference": lambda: _type_check_builtin_reference(expression),
    }
    checker = checkers.get(variety)
    if checker is None:
        assert False, "Unknown expression variety {!r}".format(variety)
    else:
        checker()
def _annotate_as_integer(expression):
    """Sets the type annotation of `expression` to integer."""
    expression_type = ir_data_utils.builder(expression).type
    expression_type.integer.CopyFrom(ir_data.IntegerType())
def _annotate_as_boolean(expression):
    """Sets the type annotation of `expression` to boolean."""
    expression_type = ir_data_utils.builder(expression).type
    expression_type.boolean.CopyFrom(ir_data.BooleanType())
def _type_check(
    expression, source_file_name, errors, type_oneof, type_name, expression_name
):
    """Reports an error unless `expression` is annotated with `type_oneof`.

    Arguments:
        expression: an already-annotated ir_data.Expression.
        source_file_name: the file name attached to the error.
        errors: a list to which a one-element list of errors is appended.
        type_oneof: the expected name of the set `type` oneof field.
        type_name: the expected type as it should read in the message.
        expression_name: what the expression is, as it should read in the
            message (e.g. "Array size").
    """
    actual_oneof = ir_data_utils.reader(expression).type.WhichOneof("type")
    if actual_oneof == type_oneof:
        return
    message = "{} must be {}.".format(expression_name, type_name)
    errors.append(
        [error.error(source_file_name, expression.source_location, message)]
    )
def _type_check_integer(expression, source_file_name, errors, expression_name):
    """Reports an error unless `expression` is annotated as an integer."""
    _type_check(
        expression,
        source_file_name,
        errors,
        type_oneof="integer",
        type_name="an integer",
        expression_name=expression_name,
    )
def _type_check_boolean(expression, source_file_name, errors, expression_name):
    """Reports an error unless `expression` is annotated as a boolean."""
    _type_check(
        expression,
        source_file_name,
        errors,
        type_oneof="boolean",
        type_name="a boolean",
        expression_name=expression_name,
    )
def _kind_check_field_reference(expression, source_file_name, errors, expression_name):
    """Reports an error unless `expression` is syntactically a field reference.

    This checks the expression's kind (which oneof of `expression` is set),
    not its type.
    """
    is_field = expression.WhichOneof("expression") == "field_reference"
    if is_field:
        return
    message = "{} must be a field.".format(expression_name)
    errors.append(
        [error.error(source_file_name, expression.source_location, message)]
    )
def _type_check_integer_constant(expression):
    """Annotates an integer literal, which is always of integer type."""
    _annotate_as_integer(expression=expression)
def _type_check_constant_reference(expression, source_file_name, ir, errors):
    """Annotates the type of a static (constant) reference.

    - A reference to an enum value gets the enum type that contains the value.
    - A reference to a virtual field gets the type of the field's
      read_transform, which is type checked first if necessary.
    - A reference to a physical field is reported as an error and left
      unannotated.
    """
    referred_name = expression.constant_reference.canonical_name
    referred_object = ir_util.find_object(referred_name, ir)

    if isinstance(referred_object, ir_data.EnumValue):
        enum_name = ir_data_utils.builder(expression).type.enumeration.name
        enum_name.CopyFrom(expression.constant_reference)
        # The copied reference names the enum *value*; dropping its last path
        # component leaves a reference to the enclosing enum type.
        del expression.type.enumeration.name.canonical_name.object_path[-1]
        return

    if not isinstance(referred_object, ir_data.Field):
        assert False, "Unexpected constant reference type."
        return

    if ir_util.field_is_virtual(referred_object):
        read_transform = referred_object.read_transform
        # The read_transform lives in the module that defines the field, so
        # errors found in it are attributed to that module's file.
        _type_check_expression(read_transform, referred_name.module_file, ir, errors)
        ir_data_utils.builder(expression).type.CopyFrom(read_transform.type)
        return

    errors.append(
        [
            error.error(
                source_file_name,
                expression.source_location,
                "Static references to physical fields are not allowed.",
            ),
            error.note(
                referred_name.module_file,
                referred_object.source_location,
                "{} is a physical field.".format(referred_name.object_path[-1]),
            ),
        ]
    )
def _type_check_operation(expression, source_file_name, ir, errors):
    """Type checks a function/operator expression, arguments first.

    The arguments are checked (and annotated) before the operator itself, so
    that the operator-specific checker can inspect their types. The operator
    is then dispatched to the choice, comparison, or single-signature checker.
    """
    for argument in expression.function.args:
        _type_check_expression(argument, source_file_name, ir, errors)
    mapping = ir_data.FunctionMapping
    comparisons = (
        mapping.EQUALITY,
        mapping.INEQUALITY,
        mapping.LESS,
        mapping.LESS_OR_EQUAL,
        mapping.GREATER,
        mapping.GREATER_OR_EQUAL,
    )
    function_kind = expression.function.function
    if function_kind == mapping.CHOICE:
        _type_check_choice_operator(expression, source_file_name, errors)
    elif function_kind in comparisons:
        _type_check_comparison_operator(expression, source_file_name, errors)
    else:
        _type_check_monomorphic_operator(expression, source_file_name, errors)
def _type_check_monomorphic_operator(expression, source_file_name, errors):
    """Type checks an operator or function that has exactly one signature.

    The steps are:
    1. Each argument is validated with the signature's per-argument checker.
    2. The argument count is validated against the signature's bounds.
    3. The expression is annotated with the signature's result type, whether
       or not any errors were reported.

    Arguments:
        expression: a function-call expression; its arguments are expected to
            be annotated already (as _type_check_operation does).
        source_file_name: the file name attached to any reported errors.
        errors: a list to which lists of errors are appended.
    """
    arguments = expression.function.args
    operand_names = ("Left argument", "Right argument")
    positional_names = tuple(
        "Argument {}".format(index) for index in range(len(arguments))
    )

    # A signature is (result annotator, per-argument checker, argument names,
    # minimum argument count, maximum argument count or None for unbounded,
    # word used for this construct in error messages).
    integer_operator = (
        _annotate_as_integer,
        _type_check_integer,
        operand_names,
        2,
        2,
        "operator",
    )
    boolean_operator = (
        _annotate_as_boolean,
        _type_check_boolean,
        operand_names,
        2,
        2,
        "operator",
    )
    unary_integer_function = (
        _annotate_as_integer,
        _type_check_integer,
        positional_names,
        1,
        1,
        "function",
    )
    mapping = ir_data.FunctionMapping
    signatures = {
        mapping.ADDITION: integer_operator,
        mapping.SUBTRACTION: integer_operator,
        mapping.MULTIPLICATION: integer_operator,
        mapping.AND: boolean_operator,
        mapping.OR: boolean_operator,
        mapping.MAXIMUM: (
            _annotate_as_integer,
            _type_check_integer,
            positional_names,
            1,
            None,
            "function",
        ),
        mapping.PRESENCE: (
            _annotate_as_boolean,
            _kind_check_field_reference,
            positional_names,
            1,
            1,
            "function",
        ),
        mapping.UPPER_BOUND: unary_integer_function,
        mapping.LOWER_BOUND: unary_integer_function,
    }
    (
        set_result_type,
        check_argument,
        argument_names,
        min_args,
        max_args,
        kind,
    ) = signatures[expression.function.function]
    function_name = expression.function.function_name.text

    # zip() stops at the shorter sequence; surplus arguments are reported by
    # the count check below.
    for argument, argument_name in zip(arguments, argument_names):
        description = "{} of {} '{}'".format(argument_name, kind, function_name)
        check_argument(argument, source_file_name, errors, description)

    def report_count_mismatch(bound_word, bound):
        """Appends an error saying how many arguments are required."""
        message = "{} '{}' requires {} {} argument{}.".format(
            kind.title(),
            function_name,
            bound_word,
            bound,
            "s" if bound > 1 else "",
        )
        errors.append(
            [error.error(source_file_name, expression.source_location, message)]
        )

    argument_count = len(arguments)
    if argument_count < min_args:
        report_count_mismatch(
            "exactly" if min_args == max_args else "at least", min_args
        )
    if max_args is not None and argument_count > max_args:
        report_count_mismatch(
            "exactly" if min_args == max_args else "at most", max_args
        )
    set_result_type(expression)
def _type_check_local_reference(expression, ir, errors):
    """Annotates the type of a local (field or runtime parameter) reference.

    The type comes from the object named by the last component of the
    reference path:
    - runtime parameters take the expression type of their physical type;
    - virtual fields take the type of their read_transform, which is type
      checked first if needed;
    - physical fields take the expression type of their atomic physical type,
      or opaque if the field's type is not atomic (e.g. an array).

    Arguments:
        expression: an ir_data.Expression whose `field_reference` is set.
        ir: the whole IR, used to resolve references.
        errors: a list to which lists of errors are appended.
    """
    final_reference = expression.field_reference.path[-1]
    referrent = ir_util.find_object(final_reference, ir)
    assert referrent, "Local reference should be non-None after name resolution."
    if isinstance(referrent, ir_data.RuntimeParameter):
        parameter = referrent
        _set_expression_type_from_physical_type_reference(
            expression, parameter.physical_type_alias.atomic_type.reference, ir
        )
        return
    field = referrent
    if ir_util.field_is_virtual(field):
        # The read_transform is defined alongside the referenced field, so any
        # errors found in it belong to that field's module file. The file name
        # must be passed as a string, not as a Reference object.
        _type_check_expression(
            field.read_transform,
            final_reference.canonical_name.module_file,
            ir,
            errors,
        )
        ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
        return
    if not field.type.HasField("atomic_type"):
        ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType())
    else:
        _set_expression_type_from_physical_type_reference(
            expression, field.type.atomic_type.reference, ir
        )
def unbounded_expression_type_for_physical_type(type_definition):
    """Gets the ExpressionType for a field of the given TypeDefinition.

    Arguments:
        type_definition: the ir_data.TypeDefinition of a field's physical type.

    Returns:
        An ir_data.ExpressionType with the corresponding expression type filled in:
        - integer for types with a true IS_INTEGER attribute (for example,
          [prelude].UInt);
        - boolean for the Flag type;
        - enumeration (naming the type itself) for enums;
        - opaque for everything else.
        The returned ExpressionType does not have any bounds set.
    """
    # TODO(bolms): Add a `[value_type]` attribute for `external`s.
    if ir_util.get_boolean_attribute(type_definition.attribute, attributes.IS_INTEGER):
        return ir_data.ExpressionType(integer=ir_data.IntegerType())
    if tuple(type_definition.name.canonical_name.object_path) == ("Flag",):
        # This is a hack: the Flag type should say that it is a boolean.
        return ir_data.ExpressionType(boolean=ir_data.BooleanType())
    if type_definition.HasField("enumeration"):
        enum_reference = ir_data.Reference(
            canonical_name=type_definition.name.canonical_name
        )
        return ir_data.ExpressionType(
            enumeration=ir_data.EnumType(name=enum_reference)
        )
    return ir_data.ExpressionType(opaque=ir_data.OpaqueType())
def _set_expression_type_from_physical_type_reference(expression, type_reference, ir):
    """Gives `expression` the unbounded expression type of a physical type.

    Arguments:
        expression: the IR node whose `type` is set (an Expression or a
            RuntimeParameter in this module).
        type_reference: a reference to the physical type's definition.
        ir: the whole IR, used to resolve `type_reference`.
    """
    physical_type = ir_util.find_object(type_reference, ir)
    assert physical_type, "Field type should be non-None after name resolution."
    expression_type = unbounded_expression_type_for_physical_type(physical_type)
    ir_data_utils.builder(expression).type.CopyFrom(expression_type)
def _annotate_parameter_type(parameter, ir, source_file_name, errors):
    """Annotates a runtime parameter with the type of its physical type alias.

    Only atomic (non-array) physical types are allowed. Any other alias is
    reported as an error and the parameter is left unannotated.
    """
    alias = parameter.physical_type_alias
    if alias.WhichOneof("type") == "atomic_type":
        _set_expression_type_from_physical_type_reference(
            parameter, alias.atomic_type.reference, ir
        )
        return
    errors.append(
        [
            error.error(
                source_file_name,
                alias.source_location,
                "Parameters cannot be arrays.",
            )
        ]
    )
def _types_are_compatible(a, b):
    """Returns true if a and b have compatible types.

    - Different type kinds are never compatible.
    - Any two integers are compatible, and so are any two booleans.
    - Two enumerations are compatible only if they name the same enum.
    - Other kinds trigger an assertion.
    """
    kind = a.type.WhichOneof("type")
    if kind != b.type.WhichOneof("type"):
        return False
    if kind == "enumeration":
        a_name = ir_util.hashable_form_of_reference(a.type.enumeration.name)
        b_name = ir_util.hashable_form_of_reference(b.type.enumeration.name)
        return a_name == b_name
    if kind in ("integer", "boolean"):
        return True
    assert False, "_types_are_compatible works with enums, integers, booleans."
def _type_check_comparison_operator(expression, source_file_name, errors):
    """Checks the type of a comparison operator (==, !=, <, >, >=, <=).

    Both operands must be of an acceptable kind and mutually compatible.
    Only the first operand with an unacceptable kind is reported; in that case
    the expression is left unannotated. Otherwise the expression is annotated
    as a boolean.
    """
    mapping = ir_data.FunctionMapping
    if expression.function.function in (mapping.EQUALITY, mapping.INEQUALITY):
        acceptable_types = ("integer", "boolean", "enumeration")
        acceptable_types_for_humans = "an integer, boolean, or enum"
    else:
        # Ordering a boolean (<, >, ...) is likely a mistake, so booleans are
        # only accepted by == and !=.
        acceptable_types = ("integer", "enumeration")
        acceptable_types_for_humans = "an integer or enum"
    left = expression.function.args[0]
    right = expression.function.args[1]
    for operand, side in ((left, "Left"), (right, "Right")):
        if operand.type.WhichOneof("type") in acceptable_types:
            continue
        errors.append(
            [
                error.error(
                    source_file_name,
                    operand.source_location,
                    "{} argument of operator '{}' must be {}.".format(
                        side,
                        expression.function.function_name.text,
                        acceptable_types_for_humans,
                    ),
                )
            ]
        )
        return
    if not _types_are_compatible(left, right):
        errors.append(
            [
                error.error(
                    source_file_name,
                    expression.source_location,
                    "Both arguments of operator '{}' must have the same "
                    "type.".format(expression.function.function_name.text),
                )
            ]
        )
    _annotate_as_boolean(expression)
def _type_check_choice_operator(expression, source_file_name, errors):
    """Checks the type of the choice operator cond ? if_true : if_false.

    - The condition must be a boolean.
    - The if-true clause must be an integer, boolean, or enum. If it is not,
      the expression is left unannotated.
    - The if-false clause must be compatible with the if-true clause.

    When the if-true clause's kind is acceptable, the result takes its type.
    """
    args = expression.function.args
    condition = args[0]
    if condition.type.WhichOneof("type") != "boolean":
        errors.append(
            [
                error.error(
                    source_file_name,
                    condition.source_location,
                    "Condition of operator '?:' must be a boolean.",
                )
            ]
        )
    if_true = args[1]
    result_kind = if_true.type.WhichOneof("type")
    if result_kind not in ("integer", "boolean", "enumeration"):
        errors.append(
            [
                error.error(
                    source_file_name,
                    if_true.source_location,
                    "If-true clause of operator '?:' must be an integer, "
                    "boolean, or enum.",
                )
            ]
        )
        return
    # The if-false clause is only read once the if-true clause is known good.
    if_false = args[2]
    if not _types_are_compatible(if_true, if_false):
        errors.append(
            [
                error.error(
                    source_file_name,
                    expression.source_location,
                    "The if-true and if-false clauses of operator '?:' must "
                    "have the same type.",
                )
            ]
        )
    if result_kind == "integer":
        _annotate_as_integer(expression)
    elif result_kind == "boolean":
        _annotate_as_boolean(expression)
    elif result_kind == "enumeration":
        result_enum_name = ir_data_utils.builder(expression).type.enumeration.name
        result_enum_name.CopyFrom(if_true.type.enumeration.name)
    else:
        assert False, "Unexpected type for if_true."
def _type_check_boolean_constant(expression):
    """Annotates a boolean literal, which is always of boolean type."""
    _annotate_as_boolean(expression=expression)
def _type_check_builtin_reference(expression):
    """Annotates a reference to a builtin ($-prefixed) value.

    $is_statically_sized is a boolean and $static_size_in_bits is an integer.
    Any other builtin triggers an assertion.
    """
    builtin_name = expression.builtin_reference.canonical_name.object_path[0]
    annotators = {
        "$is_statically_sized": _annotate_as_boolean,
        "$static_size_in_bits": _annotate_as_integer,
    }
    annotate = annotators.get(builtin_name)
    if annotate is None:
        assert False, "Unknown builtin '{}'.".format(builtin_name)
    else:
        annotate(expression)
def _type_check_array_size(expression, source_file_name, errors):
    """Reports an error unless an array's size expression is an integer."""
    _type_check_integer(
        expression, source_file_name, errors, expression_name="Array size"
    )
def _type_check_field_location(location, source_file_name, errors):
    """Reports errors unless a field's start and size expressions are integers."""
    bounds = (
        (location.start, "Start of field"),
        (location.size, "Size of field"),
    )
    for bound_expression, description in bounds:
        _type_check_integer(bound_expression, source_file_name, errors, description)
def _type_check_field_existence_condition(field, source_file_name, errors):
    """Reports an error unless `field`'s existence condition is a boolean."""
    condition = field.existence_condition
    _type_check_boolean(
        condition, source_file_name, errors, expression_name="Existence condition"
    )
def _type_name_for_error_messages(expression_type):
    """Returns a short, human-readable name for `expression_type`.

    Arguments:
        expression_type: an ir_data.ExpressionType whose `type` oneof is set.

    Returns:
        - "integer", "boolean", or "opaque" for those kinds;
        - for enumerations, the unqualified (last path component) name of the
          enum.
    """
    kind = expression_type.WhichOneof("type")
    # Parameter-mismatch errors may name any argument type (e.g. passing
    # `true` where an integer parameter is expected, or passing a struct
    # field). They may also name a boolean declared parameter. These must be
    # reported, not crash the compiler.
    if kind in ("integer", "boolean", "opaque"):
        return kind
    if kind == "enumeration":
        # TODO(bolms): Should this be the fully-qualified name?
        return expression_type.enumeration.name.canonical_name.object_path[-1]
    assert False, "Shouldn't be here."
def _type_check_passed_parameters(atomic_type, ir, source_file_name, errors):
    """Checks the arguments passed to a parameterized physical type.

    The number of arguments must match the number of declared runtime
    parameters. Each argument's type kind must then match its parameter's
    kind. Parameters whose own kind is not integer, boolean, or enum are
    skipped here.
    """
    referenced_type = ir_util.find_object(atomic_type.reference.canonical_name, ir)
    declared = referenced_type.runtime_parameter
    passed = atomic_type.runtime_parameter
    if len(declared) != len(passed):
        errors.append(
            [
                error.error(
                    source_file_name,
                    atomic_type.source_location,
                    "Type {} requires {} parameter{}; {} parameter{} given.".format(
                        referenced_type.name.name.text,
                        len(declared),
                        "" if len(declared) == 1 else "s",
                        len(passed),
                        "" if len(passed) == 1 else "s",
                    ),
                ),
                error.note(
                    atomic_type.reference.canonical_name.module_file,
                    referenced_type.source_location,
                    "Definition of type {}.".format(referenced_type.name.name.text),
                ),
            ]
        )
        return
    for index, (declared_parameter, argument) in enumerate(zip(declared, passed)):
        declared_kind = declared_parameter.type.WhichOneof("type")
        if declared_kind not in ("integer", "boolean", "enumeration"):
            # _type_check_parameter will catch invalid parameter types at the
            # definition site; no need for another, probably-confusing error at any
            # usage sites.
            continue
        if argument.type.WhichOneof("type") == declared_kind:
            continue
        errors.append(
            [
                error.error(
                    source_file_name,
                    argument.source_location,
                    "Parameter {} of type {} must be {}, not {}.".format(
                        index,
                        referenced_type.name.name.text,
                        _type_name_for_error_messages(declared_parameter.type),
                        _type_name_for_error_messages(argument.type),
                    ),
                ),
                error.note(
                    atomic_type.reference.canonical_name.module_file,
                    declared_parameter.source_location,
                    "Parameter {} of {}.".format(index, referenced_type.name.name.text),
                ),
            ]
        )
def _type_check_parameter(runtime_parameter, source_file_name, errors):
    """Reports an error unless a runtime parameter is an integer or an enum."""
    if runtime_parameter.type.WhichOneof("type") in ("integer", "enumeration"):
        return
    errors.append(
        [
            error.error(
                source_file_name,
                runtime_parameter.physical_type_alias.source_location,
                "Runtime parameters must be integer or enum.",
            )
        ]
    )
def annotate_types(ir):
    """Adds type annotations to all expressions in ir.

    Every expression and subexpression in the IR gets a type annotation, and
    every runtime parameter gets the type of its physical type. Along the way,
    internally inconsistent expressions are reported, e.g. "1 + true", where an
    operator does not accept its operands' types.

    Arguments:
        ir: an IR to which to add type annotations

    Returns:
        A (possibly empty) list of errors.
    """
    errors = []
    passes = (
        # Expression trees are visited only at their roots:
        # _type_check_expression recurses into subexpressions itself.
        ([ir_data.Expression], _type_check_expression, {ir_data.Expression}),
        ([ir_data.RuntimeParameter], _annotate_parameter_type, None),
    )
    for pattern, action, skip in passes:
        options = {"parameters": {"errors": errors}}
        if skip is not None:
            options["skip_descendants_of"] = skip
        traverse_ir.fast_traverse_ir_top_down(ir, pattern, action, **options)
    return errors
def check_types(ir):
    """Checks that expressions within the IR have the correct top-level types.

    This assumes annotate_types has already run. It checks that:
    - field start and size expressions are integers;
    - array sizes are integers ("UInt[true]" is not a valid array type);
    - field existence conditions are booleans;
    - runtime parameters are integers or enums;
    - arguments passed to parameterized types match the declared parameters.

    Arguments:
        ir: an IR to type check.

    Returns:
        A (possibly empty) list of errors.
    """
    errors = []
    checks = (
        ([ir_data.FieldLocation], _type_check_field_location, None),
        (
            [ir_data.ArrayType, ir_data.Expression],
            _type_check_array_size,
            {ir_data.AtomicType},
        ),
        ([ir_data.Field], _type_check_field_existence_condition, None),
        ([ir_data.RuntimeParameter], _type_check_parameter, None),
        ([ir_data.AtomicType], _type_check_passed_parameters, None),
    )
    for pattern, action, skip in checks:
        options = {"parameters": {"errors": errors}}
        if skip is not None:
            options["skip_descendants_of"] = skip
        traverse_ir.fast_traverse_ir_top_down(ir, pattern, action, **options)
    return errors