Reformat all source with Black. (#176)
diff --git a/.github/scripts/set_release_version.py b/.github/scripts/set_release_version.py
index 980a852..92c95dd 100644
--- a/.github/scripts/set_release_version.py
+++ b/.github/scripts/set_release_version.py
@@ -3,5 +3,5 @@
import datetime
now_utc = datetime.datetime.now(datetime.timezone.utc)
-version = now_utc.strftime('%Y.%m%d.%H%M%S')
-print('::set-output name=version::v{}'.format(version))
+version = now_utc.strftime("%Y.%m%d.%H%M%S")
+print("::set-output name=version::v{}".format(version))
diff --git a/compiler/back_end/__init__.py b/compiler/back_end/__init__.py
index 2c31d84..086a24e 100644
--- a/compiler/back_end/__init__.py
+++ b/compiler/back_end/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/compiler/back_end/cpp/__init__.py b/compiler/back_end/cpp/__init__.py
index 2c31d84..086a24e 100644
--- a/compiler/back_end/cpp/__init__.py
+++ b/compiler/back_end/cpp/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/compiler/back_end/cpp/attributes.py b/compiler/back_end/cpp/attributes.py
index 256451b..a722143 100644
--- a/compiler/back_end/cpp/attributes.py
+++ b/compiler/back_end/cpp/attributes.py
@@ -20,6 +20,7 @@
class Attribute(str, Enum):
"""Attributes available in the C++ backend."""
+
NAMESPACE = "namespace"
ENUM_CASE = "enum_case"
@@ -37,6 +38,7 @@
Each entry is a set of (Attribute, default?) tuples, the first value being
the attribute itself, the second value being a boolean value indicating
whether the attribute is allowed to be defaulted in that scope."""
+
BITS = {
# Bits may contain an enum definition.
(Attribute.ENUM_CASE, True)
@@ -47,10 +49,12 @@
ENUM_VALUE = {
(Attribute.ENUM_CASE, False),
}
- MODULE = {
- (Attribute.NAMESPACE, False),
- (Attribute.ENUM_CASE, True),
- },
+ MODULE = (
+ {
+ (Attribute.NAMESPACE, False),
+ (Attribute.ENUM_CASE, True),
+ },
+ )
STRUCT = {
# Struct may contain an enum definition.
(Attribute.ENUM_CASE, True),
diff --git a/compiler/back_end/cpp/emboss_codegen_cpp.py b/compiler/back_end/cpp/emboss_codegen_cpp.py
index 6da9f7d..474bcc1 100644
--- a/compiler/back_end/cpp/emboss_codegen_cpp.py
+++ b/compiler/back_end/cpp/emboss_codegen_cpp.py
@@ -31,72 +31,79 @@
def _parse_command_line(argv):
- """Parses the given command-line arguments."""
- parser = argparse.ArgumentParser(description="Emboss compiler C++ back end.",
- prog=argv[0])
- parser.add_argument("--input-file",
- type=str,
- help=".emb.ir file to compile.")
- parser.add_argument("--output-file",
- type=str,
- help="Write header to file. If not specified, write " +
- "header to stdout.")
- parser.add_argument("--color-output",
- default="if_tty",
- choices=["always", "never", "if_tty", "auto"],
- help="Print error messages using color. 'auto' is a "
- "synonym for 'if_tty'.")
- parser.add_argument("--cc-enum-traits",
- action=argparse.BooleanOptionalAction,
- default=True,
- help="""Controls generation of EnumTraits by the C++
- backend""")
- return parser.parse_args(argv[1:])
+ """Parses the given command-line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Emboss compiler C++ back end.", prog=argv[0]
+ )
+ parser.add_argument("--input-file", type=str, help=".emb.ir file to compile.")
+ parser.add_argument(
+ "--output-file",
+ type=str,
+ help="Write header to file. If not specified, write " + "header to stdout.",
+ )
+ parser.add_argument(
+ "--color-output",
+ default="if_tty",
+ choices=["always", "never", "if_tty", "auto"],
+ help="Print error messages using color. 'auto' is a " "synonym for 'if_tty'.",
+ )
+ parser.add_argument(
+ "--cc-enum-traits",
+ action=argparse.BooleanOptionalAction,
+ default=True,
+ help="""Controls generation of EnumTraits by the C++
+ backend""",
+ )
+ return parser.parse_args(argv[1:])
def _show_errors(errors, ir, color_output):
- """Prints errors with source code snippets."""
- source_codes = {}
- for module in ir.module:
- source_codes[module.source_file_name] = module.source_text
- use_color = (color_output == "always" or
- (color_output in ("auto", "if_tty") and
- os.isatty(sys.stderr.fileno())))
- print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
+ """Prints errors with source code snippets."""
+ source_codes = {}
+ for module in ir.module:
+ source_codes[module.source_file_name] = module.source_text
+ use_color = color_output == "always" or (
+ color_output in ("auto", "if_tty") and os.isatty(sys.stderr.fileno())
+ )
+ print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
+
def generate_headers_and_log_errors(ir, color_output, config: header_generator.Config):
- """Generates a C++ header and logs any errors.
+ """Generates a C++ header and logs any errors.
- Arguments:
- ir: EmbossIr of the module.
- color_output: "always", "never", "if_tty", "auto"
- config: Header generation configuration.
+ Arguments:
+ ir: EmbossIr of the module.
+ color_output: "always", "never", "if_tty", "auto"
+ config: Header generation configuration.
- Returns:
- A tuple of (header, errors)
- """
- header, errors = header_generator.generate_header(ir, config)
- if errors:
- _show_errors(errors, ir, color_output)
- return (header, errors)
+ Returns:
+ A tuple of (header, errors)
+ """
+ header, errors = header_generator.generate_header(ir, config)
+ if errors:
+ _show_errors(errors, ir, color_output)
+ return (header, errors)
+
def main(flags):
- if flags.input_file:
- with open(flags.input_file) as f:
- ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, f.read())
- else:
- ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, sys.stdin.read())
- config = header_generator.Config(include_enum_traits=flags.cc_enum_traits)
- header, errors = generate_headers_and_log_errors(ir, flags.color_output, config)
- if errors:
- return 1
- if flags.output_file:
- with open(flags.output_file, "w") as f:
- f.write(header)
- else:
- print(header)
- return 0
+ if flags.input_file:
+ with open(flags.input_file) as f:
+ ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, f.read())
+ else:
+ ir = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.EmbossIr, sys.stdin.read()
+ )
+ config = header_generator.Config(include_enum_traits=flags.cc_enum_traits)
+ header, errors = generate_headers_and_log_errors(ir, flags.color_output, config)
+ if errors:
+ return 1
+ if flags.output_file:
+ with open(flags.output_file, "w") as f:
+ f.write(header)
+ else:
+ print(header)
+ return 0
-if __name__ == '__main__':
- sys.exit(main(_parse_command_line(sys.argv)))
+if __name__ == "__main__":
+ sys.exit(main(_parse_command_line(sys.argv)))
diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py
index 11fcd17..320b15f 100644
--- a/compiler/back_end/cpp/header_generator.py
+++ b/compiler/back_end/cpp/header_generator.py
@@ -34,35 +34,125 @@
from compiler.util import resources
from compiler.util import traverse_ir
-_TEMPLATES = code_template.parse_templates(resources.load(
- "compiler.back_end.cpp", "generated_code_templates"))
+_TEMPLATES = code_template.parse_templates(
+ resources.load("compiler.back_end.cpp", "generated_code_templates")
+)
-_CPP_RESERVED_WORDS = set((
- # C keywords. A few of these are not (yet) C++ keywords, but some compilers
- # accept the superset of C and C++, so we still want to avoid them.
- "asm", "auto", "break", "case", "char", "const", "continue", "default",
- "do", "double", "else", "enum", "extern", "float", "for", "fortran", "goto",
- "if", "inline", "int", "long", "register", "restrict", "return", "short",
- "signed", "sizeof", "static", "struct", "switch", "typedef", "union",
- "unsigned", "void", "volatile", "while", "_Alignas", "_Alignof", "_Atomic",
- "_Bool", "_Complex", "_Generic", "_Imaginary", "_Noreturn", "_Pragma",
- "_Static_assert", "_Thread_local",
- # The following are not technically reserved words, but collisions are
- # likely due to the standard macros.
- "complex", "imaginary", "noreturn",
- # C++ keywords that are not also C keywords.
- "alignas", "alignof", "and", "and_eq", "asm", "bitand", "bitor", "bool",
- "catch", "char16_t", "char32_t", "class", "compl", "concept", "constexpr",
- "const_cast", "decltype", "delete", "dynamic_cast", "explicit", "export",
- "false", "friend", "mutable", "namespace", "new", "noexcept", "not",
- "not_eq", "nullptr", "operator", "or", "or_eq", "private", "protected",
- "public", "reinterpret_cast", "requires", "static_assert", "static_cast",
- "template", "this", "thread_local", "throw", "true", "try", "typeid",
- "typename", "using", "virtual", "wchar_t", "xor", "xor_eq",
- # "NULL" is not a keyword, but is still very likely to cause problems if
- # used as a namespace name.
- "NULL",
-))
+_CPP_RESERVED_WORDS = set(
+ (
+ # C keywords. A few of these are not (yet) C++ keywords, but some compilers
+ # accept the superset of C and C++, so we still want to avoid them.
+ "asm",
+ "auto",
+ "break",
+ "case",
+ "char",
+ "const",
+ "continue",
+ "default",
+ "do",
+ "double",
+ "else",
+ "enum",
+ "extern",
+ "float",
+ "for",
+ "fortran",
+ "goto",
+ "if",
+ "inline",
+ "int",
+ "long",
+ "register",
+ "restrict",
+ "return",
+ "short",
+ "signed",
+ "sizeof",
+ "static",
+ "struct",
+ "switch",
+ "typedef",
+ "union",
+ "unsigned",
+ "void",
+ "volatile",
+ "while",
+ "_Alignas",
+ "_Alignof",
+ "_Atomic",
+ "_Bool",
+ "_Complex",
+ "_Generic",
+ "_Imaginary",
+ "_Noreturn",
+ "_Pragma",
+ "_Static_assert",
+ "_Thread_local",
+ # The following are not technically reserved words, but collisions are
+ # likely due to the standard macros.
+ "complex",
+ "imaginary",
+ "noreturn",
+ # C++ keywords that are not also C keywords.
+ "alignas",
+ "alignof",
+ "and",
+ "and_eq",
+ "asm",
+ "bitand",
+ "bitor",
+ "bool",
+ "catch",
+ "char16_t",
+ "char32_t",
+ "class",
+ "compl",
+ "concept",
+ "constexpr",
+ "const_cast",
+ "decltype",
+ "delete",
+ "dynamic_cast",
+ "explicit",
+ "export",
+ "false",
+ "friend",
+ "mutable",
+ "namespace",
+ "new",
+ "noexcept",
+ "not",
+ "not_eq",
+ "nullptr",
+ "operator",
+ "or",
+ "or_eq",
+ "private",
+ "protected",
+ "public",
+ "reinterpret_cast",
+ "requires",
+ "static_assert",
+ "static_cast",
+ "template",
+ "this",
+ "thread_local",
+ "throw",
+ "true",
+ "try",
+ "typeid",
+ "typename",
+ "using",
+ "virtual",
+ "wchar_t",
+ "xor",
+ "xor_eq",
+ # "NULL" is not a keyword, but is still very likely to cause problems if
+ # used as a namespace name.
+ "NULL",
+ )
+)
# The support namespace, as a C++ namespace prefix. This namespace contains the
# Emboss C++ support classes.
@@ -71,7 +161,7 @@
# Regex matching a C++ namespace component. Captures component name.
_NS_COMPONENT_RE = r"(?:^\s*|::)\s*([a-zA-Z_][a-zA-Z0-9_]*)\s*(?=\s*$|::)"
# Regex matching a full C++ namespace (at least one namespace component).
-_NS_RE = fr"^\s*(?:{_NS_COMPONENT_RE})+\s*$"
+_NS_RE = rf"^\s*(?:{_NS_COMPONENT_RE})+\s*$"
# Regex matching an empty C++ namespace.
_NS_EMPTY_RE = r"^\s*$"
# Regex matching only the global C++ namespace.
@@ -87,1530 +177,1759 @@
# Verify that all supported enum cases have valid, implemented conversions.
for _enum_case in _SUPPORTED_ENUM_CASES:
- assert name_conversion.is_case_conversion_supported("SHOUTY_CASE", _enum_case)
+ assert name_conversion.is_case_conversion_supported("SHOUTY_CASE", _enum_case)
class Config(NamedTuple):
- """Configuration for C++ header generation."""
+ """Configuration for C++ header generation."""
- include_enum_traits: bool = True
- """Whether or not to include EnumTraits in the generated header."""
+ include_enum_traits: bool = True
+ """Whether or not to include EnumTraits in the generated header."""
def _get_namespace_components(namespace):
- """Gets the components of a C++ namespace
+ """Gets the components of a C++ namespace
- Examples:
- "::some::name::detail" -> ["some", "name", "detail"]
- "product::name" -> ["product", "name"]
- "simple" -> ["simple"]
+ Examples:
+ "::some::name::detail" -> ["some", "name", "detail"]
+ "product::name" -> ["product", "name"]
+ "simple" -> ["simple"]
- Arguments:
- namespace: A string containing the namespace. May be fully-qualified.
+ Arguments:
+ namespace: A string containing the namespace. May be fully-qualified.
- Returns:
- A list of strings, one per namespace component."""
- return re.findall(_NS_COMPONENT_RE, namespace)
+ Returns:
+ A list of strings, one per namespace component."""
+ return re.findall(_NS_COMPONENT_RE, namespace)
def _get_module_namespace(module):
- """Returns the C++ namespace of the module, as a list of components.
+ """Returns the C++ namespace of the module, as a list of components.
- Arguments:
- module: The IR of an Emboss module whose namespace should be returned.
+ Arguments:
+ module: The IR of an Emboss module whose namespace should be returned.
- Returns:
- A list of strings, one per namespace component. This list can be formatted
- as appropriate by the caller.
- """
- namespace_attr = ir_util.get_attribute(module.attribute, "namespace")
- if namespace_attr and namespace_attr.string_constant.text:
- namespace = namespace_attr.string_constant.text
- else:
- namespace = "emboss_generated_code"
- return _get_namespace_components(namespace)
+ Returns:
+ A list of strings, one per namespace component. This list can be formatted
+ as appropriate by the caller.
+ """
+ namespace_attr = ir_util.get_attribute(module.attribute, "namespace")
+ if namespace_attr and namespace_attr.string_constant.text:
+ namespace = namespace_attr.string_constant.text
+ else:
+ namespace = "emboss_generated_code"
+ return _get_namespace_components(namespace)
def _cpp_string_escape(string):
- return re.sub("['\"\\\\]", r"\\\0", string)
+ return re.sub("['\"\\\\]", r"\\\0", string)
def _get_includes(module, config: Config):
- """Returns the appropriate #includes based on module's imports."""
- includes = []
- for import_ in module.foreign_import:
- if import_.file_name.text:
- includes.append(
- code_template.format_template(
- _TEMPLATES.include,
- file_name=_cpp_string_escape(import_.file_name.text + ".h")))
- else:
- includes.append(
- code_template.format_template(
- _TEMPLATES.include,
- file_name=_cpp_string_escape(_PRELUDE_INCLUDE_FILE)))
- if config.include_enum_traits:
- includes.extend(
- [code_template.format_template(
- _TEMPLATES.include,
- file_name=_cpp_string_escape(file_name))
- for file_name in (_ENUM_VIEW_INCLUDE_FILE, _TEXT_UTIL_INCLUDE_FILE)
- ])
- return "".join(includes)
+ """Returns the appropriate #includes based on module's imports."""
+ includes = []
+ for import_ in module.foreign_import:
+ if import_.file_name.text:
+ includes.append(
+ code_template.format_template(
+ _TEMPLATES.include,
+ file_name=_cpp_string_escape(import_.file_name.text + ".h"),
+ )
+ )
+ else:
+ includes.append(
+ code_template.format_template(
+ _TEMPLATES.include,
+ file_name=_cpp_string_escape(_PRELUDE_INCLUDE_FILE),
+ )
+ )
+ if config.include_enum_traits:
+ includes.extend(
+ [
+ code_template.format_template(
+ _TEMPLATES.include, file_name=_cpp_string_escape(file_name)
+ )
+ for file_name in (
+ _ENUM_VIEW_INCLUDE_FILE,
+ _TEXT_UTIL_INCLUDE_FILE,
+ )
+ ]
+ )
+ return "".join(includes)
def _render_namespace_prefix(namespace):
- """Returns namespace rendered as a prefix, like ::foo::bar::baz."""
- return "".join(["::" + n for n in namespace])
+ """Returns namespace rendered as a prefix, like ::foo::bar::baz."""
+ return "".join(["::" + n for n in namespace])
def _render_integer(value):
- """Returns a C++ string representation of a constant integer."""
- integer_type = _cpp_integer_type_for_range(value, value)
- assert integer_type, ("Bug: value should never be outside [-2**63, 2**64), "
- "got {}.".format(value))
- # C++ literals are always positive. Negative constants are actually the
- # positive literal with the unary `-` operator applied.
- #
- # This means that C++ compilers for 2s-complement systems get finicky about
- # minimum integers: if you feed `-9223372036854775808` into GCC, with -Wall,
- # you get:
- #
- # warning: integer constant is so large that it is unsigned
- #
- # and Clang gives:
- #
- # warning: integer literal is too large to be represented in a signed
- # integer type, interpreting as unsigned [-Wimplicitly-unsigned-literal]
- #
- # and MSVC:
- #
- # warning C4146: unary minus operator applied to unsigned type, result
- # still unsigned
- #
- # So, workaround #1: -(2**63) must be written `(-9223372036854775807 - 1)`.
- #
- # The next problem is that MSVC (but not Clang or GCC) will pick `unsigned`
- # as the type of a literal like `2147483648`. As far as I can tell, this is a
- # violation of the C++11 standard, but it's possible that the final standard
- # has different rules. (MSVC seems to treat decimal literals the way that the
- # standard says octal and hexadecimal literals should be treated.)
- #
- # Luckily, workaround #2: we can unconditionally append `LL` to all constants
- # to force them to be interpreted as `long long` (or `unsigned long long` for
- # `ULL`-suffixed constants), and then use a narrowing cast to the appropriate
- # type, without any warnings on any major compilers.
- #
- # TODO(bolms): This suffix computation is kind of a hack.
- suffix = "U" if "uint" in integer_type else ""
- if value == -(2**63):
- return "static_cast</**/{0}>({1}LL - 1)".format(integer_type, -(2**63 - 1))
- else:
- return "static_cast</**/{0}>({1}{2}LL)".format(integer_type, value, suffix)
+ """Returns a C++ string representation of a constant integer."""
+ integer_type = _cpp_integer_type_for_range(value, value)
+ assert (
+ integer_type
+ ), "Bug: value should never be outside [-2**63, 2**64), " "got {}.".format(value)
+ # C++ literals are always positive. Negative constants are actually the
+ # positive literal with the unary `-` operator applied.
+ #
+ # This means that C++ compilers for 2s-complement systems get finicky about
+ # minimum integers: if you feed `-9223372036854775808` into GCC, with -Wall,
+ # you get:
+ #
+ # warning: integer constant is so large that it is unsigned
+ #
+ # and Clang gives:
+ #
+ # warning: integer literal is too large to be represented in a signed
+ # integer type, interpreting as unsigned [-Wimplicitly-unsigned-literal]
+ #
+ # and MSVC:
+ #
+ # warning C4146: unary minus operator applied to unsigned type, result
+ # still unsigned
+ #
+ # So, workaround #1: -(2**63) must be written `(-9223372036854775807 - 1)`.
+ #
+ # The next problem is that MSVC (but not Clang or GCC) will pick `unsigned`
+ # as the type of a literal like `2147483648`. As far as I can tell, this is a
+ # violation of the C++11 standard, but it's possible that the final standard
+ # has different rules. (MSVC seems to treat decimal literals the way that the
+ # standard says octal and hexadecimal literals should be treated.)
+ #
+ # Luckily, workaround #2: we can unconditionally append `LL` to all constants
+ # to force them to be interpreted as `long long` (or `unsigned long long` for
+ # `ULL`-suffixed constants), and then use a narrowing cast to the appropriate
+ # type, without any warnings on any major compilers.
+ #
+ # TODO(bolms): This suffix computation is kind of a hack.
+ suffix = "U" if "uint" in integer_type else ""
+ if value == -(2**63):
+ return "static_cast</**/{0}>({1}LL - 1)".format(integer_type, -(2**63 - 1))
+ else:
+ return "static_cast</**/{0}>({1}{2}LL)".format(integer_type, value, suffix)
def _maybe_type(wrapped_type):
- return "::emboss::support::Maybe</**/{}>".format(wrapped_type)
+ return "::emboss::support::Maybe</**/{}>".format(wrapped_type)
def _render_integer_for_expression(value):
- integer_type = _cpp_integer_type_for_range(value, value)
- return "{0}({1})".format(_maybe_type(integer_type), _render_integer(value))
+ integer_type = _cpp_integer_type_for_range(value, value)
+ return "{0}({1})".format(_maybe_type(integer_type), _render_integer(value))
def _wrap_in_namespace(body, namespace):
- """Returns the given body wrapped in the given namespace."""
- for component in reversed(namespace):
- body = code_template.format_template(_TEMPLATES.namespace_wrap,
- component=component,
- body=body) + "\n"
- return body
+ """Returns the given body wrapped in the given namespace."""
+ for component in reversed(namespace):
+ body = (
+ code_template.format_template(
+ _TEMPLATES.namespace_wrap, component=component, body=body
+ )
+ + "\n"
+ )
+ return body
def _get_type_size(type_ir, ir):
- size = ir_util.fixed_size_of_type_in_bits(type_ir, ir)
- assert size is not None, (
- "_get_type_size should only be called for constant-sized types.")
- return size
+ size = ir_util.fixed_size_of_type_in_bits(type_ir, ir)
+ assert (
+ size is not None
+ ), "_get_type_size should only be called for constant-sized types."
+ return size
def _offset_storage_adapter(buffer_type, alignment, static_offset):
- return "{}::template OffsetStorageType</**/{}, {}>".format(
- buffer_type, alignment, static_offset)
+ return "{}::template OffsetStorageType</**/{}, {}>".format(
+ buffer_type, alignment, static_offset
+ )
def _bytes_to_bits_convertor(buffer_type, byte_order, size):
- assert byte_order, "byte_order should not be empty."
- return "{}::BitBlock</**/{}::{}ByteOrderer<typename {}>, {}>".format(
- _SUPPORT_NAMESPACE,
- _SUPPORT_NAMESPACE,
- byte_order,
- buffer_type,
- size)
+ assert byte_order, "byte_order should not be empty."
+ return "{}::BitBlock</**/{}::{}ByteOrderer<typename {}>, {}>".format(
+ _SUPPORT_NAMESPACE, _SUPPORT_NAMESPACE, byte_order, buffer_type, size
+ )
def _get_fully_qualified_namespace(name, ir):
- module = ir_util.find_object((name.module_file,), ir)
- namespace = _render_namespace_prefix(_get_module_namespace(module))
- return namespace + "".join(["::" + str(s) for s in name.object_path[:-1]])
+ module = ir_util.find_object((name.module_file,), ir)
+ namespace = _render_namespace_prefix(_get_module_namespace(module))
+ return namespace + "".join(["::" + str(s) for s in name.object_path[:-1]])
def _get_unqualified_name(name):
- return name.object_path[-1]
+ return name.object_path[-1]
def _get_fully_qualified_name(name, ir):
- return (_get_fully_qualified_namespace(name, ir) + "::" +
- _get_unqualified_name(name))
+ return _get_fully_qualified_namespace(name, ir) + "::" + _get_unqualified_name(name)
-def _get_adapted_cpp_buffer_type_for_field(type_definition, size_in_bits,
- buffer_type, byte_order,
- parent_addressable_unit):
- """Returns the adapted C++ type information needed to construct a view."""
- if (parent_addressable_unit == ir_data.AddressableUnit.BYTE and
- type_definition.addressable_unit == ir_data.AddressableUnit.BIT):
- assert byte_order
- return _bytes_to_bits_convertor(buffer_type, byte_order, size_in_bits)
- else:
- assert parent_addressable_unit == type_definition.addressable_unit, (
- "Addressable unit mismatch: {} vs {}".format(
- parent_addressable_unit,
- type_definition.addressable_unit))
- return buffer_type
+def _get_adapted_cpp_buffer_type_for_field(
+ type_definition, size_in_bits, buffer_type, byte_order, parent_addressable_unit
+):
+ """Returns the adapted C++ type information needed to construct a view."""
+ if (
+ parent_addressable_unit == ir_data.AddressableUnit.BYTE
+ and type_definition.addressable_unit == ir_data.AddressableUnit.BIT
+ ):
+ assert byte_order
+ return _bytes_to_bits_convertor(buffer_type, byte_order, size_in_bits)
+ else:
+ assert (
+ parent_addressable_unit == type_definition.addressable_unit
+ ), "Addressable unit mismatch: {} vs {}".format(
+ parent_addressable_unit, type_definition.addressable_unit
+ )
+ return buffer_type
def _get_cpp_view_type_for_type_definition(
- type_definition, size, ir, buffer_type, byte_order, parent_addressable_unit,
- validator):
- """Returns the C++ type information needed to construct a view.
+ type_definition,
+ size,
+ ir,
+ buffer_type,
+ byte_order,
+ parent_addressable_unit,
+ validator,
+):
+ """Returns the C++ type information needed to construct a view.
- Returns the C++ type for a view of the given Emboss TypeDefinition, and the
- C++ types of its parameters, if any.
+ Returns the C++ type for a view of the given Emboss TypeDefinition, and the
+ C++ types of its parameters, if any.
- Arguments:
- type_definition: The ir_data.TypeDefinition whose view should be
- constructed.
- size: The size, in type_definition.addressable_units, of the instantiated
- type, or None if it is not known at compile time.
- ir: The complete IR.
- buffer_type: The C++ type to be used as the Storage parameter of the view
- (e.g., "ContiguousBuffer<...>").
- byte_order: For BIT types which are direct children of BYTE types,
- "LittleEndian", "BigEndian", or "None". Otherwise, None.
- parent_addressable_unit: The addressable_unit_size of the structure
- containing this structure.
- validator: The name of the validator type to be injected into the view.
+ Arguments:
+ type_definition: The ir_data.TypeDefinition whose view should be
+ constructed.
+ size: The size, in type_definition.addressable_units, of the instantiated
+ type, or None if it is not known at compile time.
+ ir: The complete IR.
+ buffer_type: The C++ type to be used as the Storage parameter of the view
+ (e.g., "ContiguousBuffer<...>").
+ byte_order: For BIT types which are direct children of BYTE types,
+ "LittleEndian", "BigEndian", or "None". Otherwise, None.
+ parent_addressable_unit: The addressable_unit_size of the structure
+ containing this structure.
+ validator: The name of the validator type to be injected into the view.
- Returns:
- A tuple of: the C++ view type and a (possibly-empty) list of the C++ types
- of Emboss parameters which must be passed to the view's constructor.
- """
- adapted_buffer_type = _get_adapted_cpp_buffer_type_for_field(
- type_definition, size, buffer_type, byte_order, parent_addressable_unit)
- if type_definition.HasField("external"):
- # Externals do not (yet) support runtime parameters.
- return code_template.format_template(
- _TEMPLATES.external_view_type,
- namespace=_get_fully_qualified_namespace(
- type_definition.name.canonical_name, ir),
- name=_get_unqualified_name(type_definition.name.canonical_name),
- bits=size,
- validator=validator,
- buffer_type=adapted_buffer_type), []
- elif type_definition.HasField("structure"):
- parameter_types = []
- for parameter in type_definition.runtime_parameter:
- parameter_types.append(
- _cpp_basic_type_for_expression_type(parameter.type, ir))
- return code_template.format_template(
- _TEMPLATES.structure_view_type,
- namespace=_get_fully_qualified_namespace(
- type_definition.name.canonical_name, ir),
- name=_get_unqualified_name(type_definition.name.canonical_name),
- buffer_type=adapted_buffer_type), parameter_types
- elif type_definition.HasField("enumeration"):
- return code_template.format_template(
- _TEMPLATES.enum_view_type,
- support_namespace=_SUPPORT_NAMESPACE,
- enum_type=_get_fully_qualified_name(type_definition.name.canonical_name,
- ir),
- bits=size,
- validator=validator,
- buffer_type=adapted_buffer_type), []
- else:
- assert False, "Unknown variety of type {}".format(type_definition)
+ Returns:
+ A tuple of: the C++ view type and a (possibly-empty) list of the C++ types
+ of Emboss parameters which must be passed to the view's constructor.
+ """
+ adapted_buffer_type = _get_adapted_cpp_buffer_type_for_field(
+ type_definition, size, buffer_type, byte_order, parent_addressable_unit
+ )
+ if type_definition.HasField("external"):
+ # Externals do not (yet) support runtime parameters.
+ return (
+ code_template.format_template(
+ _TEMPLATES.external_view_type,
+ namespace=_get_fully_qualified_namespace(
+ type_definition.name.canonical_name, ir
+ ),
+ name=_get_unqualified_name(type_definition.name.canonical_name),
+ bits=size,
+ validator=validator,
+ buffer_type=adapted_buffer_type,
+ ),
+ [],
+ )
+ elif type_definition.HasField("structure"):
+ parameter_types = []
+ for parameter in type_definition.runtime_parameter:
+ parameter_types.append(
+ _cpp_basic_type_for_expression_type(parameter.type, ir)
+ )
+ return (
+ code_template.format_template(
+ _TEMPLATES.structure_view_type,
+ namespace=_get_fully_qualified_namespace(
+ type_definition.name.canonical_name, ir
+ ),
+ name=_get_unqualified_name(type_definition.name.canonical_name),
+ buffer_type=adapted_buffer_type,
+ ),
+ parameter_types,
+ )
+ elif type_definition.HasField("enumeration"):
+ return (
+ code_template.format_template(
+ _TEMPLATES.enum_view_type,
+ support_namespace=_SUPPORT_NAMESPACE,
+ enum_type=_get_fully_qualified_name(
+ type_definition.name.canonical_name, ir
+ ),
+ bits=size,
+ validator=validator,
+ buffer_type=adapted_buffer_type,
+ ),
+ [],
+ )
+ else:
+ assert False, "Unknown variety of type {}".format(type_definition)
def _get_cpp_view_type_for_physical_type(
- type_ir, size, byte_order, ir, buffer_type, parent_addressable_unit,
- validator):
- """Returns the C++ type information needed to construct a field's view.
+ type_ir, size, byte_order, ir, buffer_type, parent_addressable_unit, validator
+):
+ """Returns the C++ type information needed to construct a field's view.
- Returns the C++ type of an ir_data.Type, and the C++ types of its parameters,
- if any.
+ Returns the C++ type of an ir_data.Type, and the C++ types of its parameters,
+ if any.
- Arguments:
- type_ir: The ir_data.Type whose view should be constructed.
- size: The size, in type_definition.addressable_units, of the instantiated
- type, or None if it is not known at compile time.
- byte_order: For BIT types which are direct children of BYTE types,
- "LittleEndian", "BigEndian", or "None". Otherwise, None.
- ir: The complete IR.
- buffer_type: The C++ type to be used as the Storage parameter of the view
- (e.g., "ContiguousBuffer<...>").
- parent_addressable_unit: The addressable_unit_size of the structure
- containing this type.
- validator: The name of the validator type to be injected into the view.
+ Arguments:
+ type_ir: The ir_data.Type whose view should be constructed.
+ size: The size, in type_definition.addressable_units, of the instantiated
+ type, or None if it is not known at compile time.
+ byte_order: For BIT types which are direct children of BYTE types,
+ "LittleEndian", "BigEndian", or "None". Otherwise, None.
+ ir: The complete IR.
+ buffer_type: The C++ type to be used as the Storage parameter of the view
+ (e.g., "ContiguousBuffer<...>").
+ parent_addressable_unit: The addressable_unit_size of the structure
+ containing this type.
+ validator: The name of the validator type to be injected into the view.
- Returns:
- A tuple of: the C++ type for a view of the given Emboss Type and a list of
- the C++ types of any parameters of the view type, which should be passed
- to the view's constructor.
- """
- if ir_util.is_array(type_ir):
- # An array view is parameterized by the element's view type.
- base_type = type_ir.array_type.base_type
- element_size_in_bits = _get_type_size(base_type, ir)
- assert element_size_in_bits, (
- "TODO(bolms): Implement arrays of dynamically-sized elements.")
- assert element_size_in_bits % parent_addressable_unit == 0, (
- "Array elements must fall on byte boundaries.")
- element_size = element_size_in_bits // parent_addressable_unit
- element_view_type, element_view_parameter_types, element_view_parameters = (
- _get_cpp_view_type_for_physical_type(
- base_type, element_size_in_bits, byte_order, ir,
- _offset_storage_adapter(buffer_type, element_size, 0),
- parent_addressable_unit, validator))
- return (
- code_template.format_template(
- _TEMPLATES.array_view_adapter,
- support_namespace=_SUPPORT_NAMESPACE,
- # TODO(bolms): The element size should be calculable from the field
- # size and array length.
- element_view_type=element_view_type,
- element_view_parameter_types="".join(
- ", " + p for p in element_view_parameter_types),
- element_size=element_size,
- addressable_unit_size=int(parent_addressable_unit),
- buffer_type=buffer_type),
- element_view_parameter_types,
- element_view_parameters
- )
- else:
- assert type_ir.HasField("atomic_type")
- reference = type_ir.atomic_type.reference
- referenced_type = ir_util.find_object(reference, ir)
- if parent_addressable_unit > referenced_type.addressable_unit:
- assert byte_order, repr(type_ir)
- reader, parameter_types = _get_cpp_view_type_for_type_definition(
- referenced_type, size, ir, buffer_type, byte_order,
- parent_addressable_unit, validator)
- return reader, parameter_types, list(type_ir.atomic_type.runtime_parameter)
+ Returns:
+ A tuple of: the C++ type for a view of the given Emboss Type and a list of
+ the C++ types of any parameters of the view type, which should be passed
+ to the view's constructor.
+ """
+ if ir_util.is_array(type_ir):
+ # An array view is parameterized by the element's view type.
+ base_type = type_ir.array_type.base_type
+ element_size_in_bits = _get_type_size(base_type, ir)
+ assert (
+ element_size_in_bits
+ ), "TODO(bolms): Implement arrays of dynamically-sized elements."
+ assert (
+ element_size_in_bits % parent_addressable_unit == 0
+ ), "Array elements must fall on byte boundaries."
+ element_size = element_size_in_bits // parent_addressable_unit
+ element_view_type, element_view_parameter_types, element_view_parameters = (
+ _get_cpp_view_type_for_physical_type(
+ base_type,
+ element_size_in_bits,
+ byte_order,
+ ir,
+ _offset_storage_adapter(buffer_type, element_size, 0),
+ parent_addressable_unit,
+ validator,
+ )
+ )
+ return (
+ code_template.format_template(
+ _TEMPLATES.array_view_adapter,
+ support_namespace=_SUPPORT_NAMESPACE,
+ # TODO(bolms): The element size should be calculable from the field
+ # size and array length.
+ element_view_type=element_view_type,
+ element_view_parameter_types="".join(
+ ", " + p for p in element_view_parameter_types
+ ),
+ element_size=element_size,
+ addressable_unit_size=int(parent_addressable_unit),
+ buffer_type=buffer_type,
+ ),
+ element_view_parameter_types,
+ element_view_parameters,
+ )
+ else:
+ assert type_ir.HasField("atomic_type")
+ reference = type_ir.atomic_type.reference
+ referenced_type = ir_util.find_object(reference, ir)
+ if parent_addressable_unit > referenced_type.addressable_unit:
+ assert byte_order, repr(type_ir)
+ reader, parameter_types = _get_cpp_view_type_for_type_definition(
+ referenced_type,
+ size,
+ ir,
+ buffer_type,
+ byte_order,
+ parent_addressable_unit,
+ validator,
+ )
+ return reader, parameter_types, list(type_ir.atomic_type.runtime_parameter)
def _render_variable(variable, prefix=""):
- """Renders a variable reference (e.g., `foo` or `foo.bar.baz`) in C++ code."""
- # A "variable" could be an immediate field or a subcomponent of an immediate
- # field. For either case, in C++ it is valid to just use the last component
- # of the name; it is not necessary to qualify the method with the type.
- components = []
- for component in variable:
- components.append(_cpp_field_name(component[-1]) + "()")
- components[-1] = prefix + components[-1]
- return ".".join(components)
+ """Renders a variable reference (e.g., `foo` or `foo.bar.baz`) in C++ code."""
+ # A "variable" could be an immediate field or a subcomponent of an immediate
+ # field. For either case, in C++ it is valid to just use the last component
+ # of the name; it is not necessary to qualify the method with the type.
+ components = []
+ for component in variable:
+ components.append(_cpp_field_name(component[-1]) + "()")
+ components[-1] = prefix + components[-1]
+ return ".".join(components)
def _render_enum_value(enum_type, ir):
- cpp_enum_type = _get_fully_qualified_name(enum_type.name.canonical_name, ir)
- return "{}(static_cast</**/{}>({}))".format(
- _maybe_type(cpp_enum_type), cpp_enum_type, enum_type.value)
+ cpp_enum_type = _get_fully_qualified_name(enum_type.name.canonical_name, ir)
+ return "{}(static_cast</**/{}>({}))".format(
+ _maybe_type(cpp_enum_type), cpp_enum_type, enum_type.value
+ )
def _builtin_function_name(function):
- """Returns the C++ operator name corresponding to an Emboss operator."""
- functions = {
- ir_data.FunctionMapping.ADDITION: "Sum",
- ir_data.FunctionMapping.SUBTRACTION: "Difference",
- ir_data.FunctionMapping.MULTIPLICATION: "Product",
- ir_data.FunctionMapping.EQUALITY: "Equal",
- ir_data.FunctionMapping.INEQUALITY: "NotEqual",
- ir_data.FunctionMapping.AND: "And",
- ir_data.FunctionMapping.OR: "Or",
- ir_data.FunctionMapping.LESS: "LessThan",
- ir_data.FunctionMapping.LESS_OR_EQUAL: "LessThanOrEqual",
- ir_data.FunctionMapping.GREATER: "GreaterThan",
- ir_data.FunctionMapping.GREATER_OR_EQUAL: "GreaterThanOrEqual",
- ir_data.FunctionMapping.CHOICE: "Choice",
- ir_data.FunctionMapping.MAXIMUM: "Maximum",
- }
- return functions[function]
+ """Returns the C++ operator name corresponding to an Emboss operator."""
+ functions = {
+ ir_data.FunctionMapping.ADDITION: "Sum",
+ ir_data.FunctionMapping.SUBTRACTION: "Difference",
+ ir_data.FunctionMapping.MULTIPLICATION: "Product",
+ ir_data.FunctionMapping.EQUALITY: "Equal",
+ ir_data.FunctionMapping.INEQUALITY: "NotEqual",
+ ir_data.FunctionMapping.AND: "And",
+ ir_data.FunctionMapping.OR: "Or",
+ ir_data.FunctionMapping.LESS: "LessThan",
+ ir_data.FunctionMapping.LESS_OR_EQUAL: "LessThanOrEqual",
+ ir_data.FunctionMapping.GREATER: "GreaterThan",
+ ir_data.FunctionMapping.GREATER_OR_EQUAL: "GreaterThanOrEqual",
+ ir_data.FunctionMapping.CHOICE: "Choice",
+ ir_data.FunctionMapping.MAXIMUM: "Maximum",
+ }
+ return functions[function]
def _cpp_basic_type_for_expression_type(expression_type, ir):
- """Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType."""
- if expression_type.WhichOneof("type") == "integer":
- return _cpp_integer_type_for_range(
- int(expression_type.integer.minimum_value),
- int(expression_type.integer.maximum_value))
- elif expression_type.WhichOneof("type") == "boolean":
- return "bool"
- elif expression_type.WhichOneof("type") == "enumeration":
- return _get_fully_qualified_name(
- expression_type.enumeration.name.canonical_name, ir)
- else:
- assert False, "Unknown expression type " + expression_type.WhichOneof(
- "type")
+ """Returns the C++ basic type (int32_t, bool, etc.) for an ExpressionType."""
+ if expression_type.WhichOneof("type") == "integer":
+ return _cpp_integer_type_for_range(
+ int(expression_type.integer.minimum_value),
+ int(expression_type.integer.maximum_value),
+ )
+ elif expression_type.WhichOneof("type") == "boolean":
+ return "bool"
+ elif expression_type.WhichOneof("type") == "enumeration":
+ return _get_fully_qualified_name(
+ expression_type.enumeration.name.canonical_name, ir
+ )
+ else:
+ assert False, "Unknown expression type " + expression_type.WhichOneof("type")
def _cpp_basic_type_for_expression(expression, ir):
- """Returns the C++ basic type (int32_t, bool, etc.) for an Expression."""
- return _cpp_basic_type_for_expression_type(expression.type, ir)
+ """Returns the C++ basic type (int32_t, bool, etc.) for an Expression."""
+ return _cpp_basic_type_for_expression_type(expression.type, ir)
def _cpp_integer_type_for_range(min_val, max_val):
- """Returns the appropriate C++ integer type to hold min_val up to max_val."""
- # The choice of int32_t, uint32_t, int64_t, then uint64_t is somewhat
- # arbitrary here, and might not be perfectly ideal. I (bolms@) have chosen
- # this set of types to a) minimize the number of casts that occur in
- # arithmetic expressions, and b) favor 32-bit arithmetic, which is mostly
- # "cheapest" on current (2018) systems. Signed integers are also preferred
- # over unsigned so that the C++ compiler can take advantage of undefined
- # overflow.
- for size in (32, 64):
- if min_val >= -(2**(size - 1)) and max_val <= 2**(size - 1) - 1:
- return "::std::int{}_t".format(size)
- elif min_val >= 0 and max_val <= 2**size - 1:
- return "::std::uint{}_t".format(size)
- return None
-
-
-def _cpp_integer_type_for_enum(max_bits, is_signed):
- """Returns the appropriate C++ integer type to hold an enum."""
- # This is used to determine the `X` in `enum class : X`.
- #
- # Unlike _cpp_integer_type_for_range, the type chosen here is used for actual
- # storage. Further, sizes smaller than 64 are explicitly chosen by a human
- # author, so take the smallest size that can hold the given number of bits.
- #
- # Technically, the C++ standard allows some of these sizes of integer to not
- # exist, and other sizes (say, int24_t) might exist, but in practice this set
- # is almost always available. If you're compiling for some exotic DSP that
- # uses unusual int sizes, email emboss-dev@google.com.
- for size in (8, 16, 32, 64):
- if max_bits <= size:
- return "::std::{}int{}_t".format("" if is_signed else "u", size)
- assert False, f"Invalid value {max_bits} for maximum_bits"
-
-
-def _render_builtin_operation(expression, ir, field_reader, subexpressions):
- """Renders a built-in operation (+, -, &&, etc.) into C++ code."""
- assert expression.function.function not in (
- ir_data.FunctionMapping.UPPER_BOUND, ir_data.FunctionMapping.LOWER_BOUND), (
- "UPPER_BOUND and LOWER_BOUND should be constant.")
- if expression.function.function == ir_data.FunctionMapping.PRESENCE:
- return field_reader.render_existence(expression.function.args[0],
- subexpressions)
- args = expression.function.args
- rendered_args = [
- _render_expression(arg, ir, field_reader, subexpressions).rendered
- for arg in args]
- minimum_integers = []
- maximum_integers = []
- enum_types = set()
- have_boolean_types = False
- for subexpression in [expression] + list(args):
- if subexpression.type.WhichOneof("type") == "integer":
- minimum_integers.append(int(subexpression.type.integer.minimum_value))
- maximum_integers.append(int(subexpression.type.integer.maximum_value))
- elif subexpression.type.WhichOneof("type") == "enumeration":
- enum_types.add(_cpp_basic_type_for_expression(subexpression, ir))
- elif subexpression.type.WhichOneof("type") == "boolean":
- have_boolean_types = True
- # At present, all Emboss functions other than `$has` take and return one of
- # the following:
- #
- # integers
- # integers and booleans
- # a single enum type
- # a single enum type and booleans
- # booleans
- #
- # Really, the intermediate type is only necessary for integers, but it
- # simplifies the C++ somewhat if the appropriate enum/boolean type is provided
- # as "IntermediateT" -- it means that, e.g., the choice ("?:") operator does
- # not have to have two versions, one of which casts (some of) its arguments to
- # IntermediateT, and one of which does not.
- #
- # This is not a particularly robust scheme, but it works for all of the Emboss
- # functions I (bolms@) have written and am considering (division, modulus,
- # exponentiation, logical negation, bit shifts, bitwise and/or/xor, $min,
- # $floor, $ceil, $has).
- if minimum_integers and not enum_types:
- intermediate_type = _cpp_integer_type_for_range(min(minimum_integers),
- max(maximum_integers))
- elif len(enum_types) == 1 and not minimum_integers:
- intermediate_type = list(enum_types)[0]
- else:
- assert have_boolean_types
- assert not enum_types
- assert not minimum_integers
- intermediate_type = "bool"
- arg_types = [_cpp_basic_type_for_expression(arg, ir) for arg in args]
- result_type = _cpp_basic_type_for_expression(expression, ir)
- function_variant = "</**/{}, {}, {}>".format(
- intermediate_type, result_type, ", ".join(arg_types))
- return "::emboss::support::{}{}({})".format(
- _builtin_function_name(expression.function.function),
- function_variant, ", ".join(rendered_args))
-
-
-class _FieldRenderer(object):
- """Base class for rendering field reads."""
-
- def render_field_read_with_context(self, expression, ir, prefix,
- subexpressions):
- field = (
- prefix +
- _render_variable(ir_util.hashable_form_of_field_reference(
- expression.field_reference)))
- if subexpressions is None:
- field_expression = field
- else:
- field_expression = subexpressions.add(field)
- expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
- return ("({0}.Ok()"
- " ? {1}(static_cast</**/{2}>({0}.UncheckedRead()))"
- " : {1}())".format(
- field_expression,
- _maybe_type(expression_cpp_type),
- expression_cpp_type))
-
- def render_existence_with_context(self, expression, prefix, subexpressions):
- return "{1}{0}".format(
- _render_variable(
- ir_util.hashable_form_of_field_reference(
- expression.field_reference),
- "has_"),
- prefix)
-
-
-class _DirectFieldRenderer(_FieldRenderer):
- """Renderer for fields read from inside a structure's View type."""
-
- def render_field(self, expression, ir, subexpressions):
- return self.render_field_read_with_context(
- expression, ir, "", subexpressions)
-
- def render_existence(self, expression, subexpressions):
- return self.render_existence_with_context(expression, "", subexpressions)
-
-
-class _VirtualViewFieldRenderer(_FieldRenderer):
- """Renderer for field reads from inside a virtual field's View."""
-
- def render_existence(self, expression, subexpressions):
- return self.render_existence_with_context(
- expression, "view_.", subexpressions)
-
- def render_field(self, expression, ir, subexpressions):
- return self.render_field_read_with_context(
- expression, ir, "view_.", subexpressions)
-
-
-class _SubexpressionStore(object):
- """Holder for subexpressions to be assigned to local variables."""
-
- def __init__(self, prefix):
- self._prefix = prefix
- self._subexpr_to_name = {}
- self._index_to_subexpr = []
-
- def add(self, subexpr):
- if subexpr not in self._subexpr_to_name:
- self._index_to_subexpr.append(subexpr)
- self._subexpr_to_name[subexpr] = (
- self._prefix + str(len(self._index_to_subexpr)))
- return self._subexpr_to_name[subexpr]
-
- def subexprs(self):
- return [(self._subexpr_to_name[subexpr], subexpr)
- for subexpr in self._index_to_subexpr]
-
-
-_ExpressionResult = collections.namedtuple("ExpressionResult",
- ["rendered", "is_constant"])
-
-
-def _render_expression(expression, ir, field_reader=None, subexpressions=None):
- """Renders an expression into C++ code.
-
- Arguments:
- expression: The expression to render.
- ir: The IR in which to look up references.
- field_reader: An object with render_existence and render_field methods
- appropriate for the C++ context of the expression.
- subexpressions: A _SubexpressionStore in which to put subexpressions, or
- None if subexpressions should be inline.
-
- Returns:
- A tuple of (rendered_text, is_constant), where rendered_text is C++ code
- that can be emitted, and is_constant is True if the expression is a
- compile-time constant suitable for use in a C++11 constexpr context,
- otherwise False.
- """
- if field_reader is None:
- field_reader = _DirectFieldRenderer()
-
- # If the expression is constant, there are no guarantees that subexpressions
- # will fit into C++ types, or that operator arguments and return types can fit
- # in the same type: expressions like `-0x8000_0000_0000_0000` and
- # `0x1_0000_0000_0000_0000 - 1` can appear.
- if expression.type.WhichOneof("type") == "integer":
- if expression.type.integer.modulus == "infinity":
- return _ExpressionResult(_render_integer_for_expression(int(
- expression.type.integer.modular_value)), True)
- elif expression.type.WhichOneof("type") == "boolean":
- if expression.type.boolean.HasField("value"):
- if expression.type.boolean.value:
- return _ExpressionResult(_maybe_type("bool") + "(true)", True)
- else:
- return _ExpressionResult(_maybe_type("bool") + "(false)", True)
- elif expression.type.WhichOneof("type") == "enumeration":
- if expression.type.enumeration.HasField("value"):
- return _ExpressionResult(
- _render_enum_value(expression.type.enumeration, ir), True)
- else:
- # There shouldn't be any "opaque" type expressions here.
- assert False, "Unhandled expression type {}".format(
- expression.type.WhichOneof("type"))
-
- result = None
- # Otherwise, render the operation.
- if expression.WhichOneof("expression") == "function":
- result = _render_builtin_operation(
- expression, ir, field_reader, subexpressions)
- elif expression.WhichOneof("expression") == "field_reference":
- result = field_reader.render_field(expression, ir, subexpressions)
- elif (expression.WhichOneof("expression") == "builtin_reference" and
- expression.builtin_reference.canonical_name.object_path[-1] ==
- "$logical_value"):
- return _ExpressionResult(
- _maybe_type("decltype(emboss_reserved_local_value)") +
- "(emboss_reserved_local_value)", False)
-
- # Any of the constant expression types should have been handled in the
- # previous section.
- assert result is not None, "Unable to render expression {}".format(
- str(expression))
-
- if subexpressions is None:
- return _ExpressionResult(result, False)
- else:
- return _ExpressionResult(subexpressions.add(result), False)
-
-
-def _render_existence_test(field, ir, subexpressions=None):
- return _render_expression(field.existence_condition, ir, subexpressions)
-
-
-def _alignment_of_location(location):
- constraints = location.start.type.integer
- if constraints.modulus == "infinity":
- # The C++ templates use 0 as a sentinel value meaning infinity for
- # alignment.
- return 0, constraints.modular_value
- else:
- return constraints.modulus, constraints.modular_value
-
-
-def _get_cpp_type_reader_of_field(field_ir, ir, buffer_type, validator,
- parent_addressable_unit):
- """Returns the C++ view type for a field."""
- field_size = None
- if field_ir.type.HasField("size_in_bits"):
- field_size = ir_util.constant_value(field_ir.type.size_in_bits)
- assert field_size is not None
- elif ir_util.is_constant(field_ir.location.size):
- # TODO(bolms): Normalize the IR so that this clause is unnecessary.
- field_size = (ir_util.constant_value(field_ir.location.size) *
- parent_addressable_unit)
- byte_order_attr = ir_util.get_attribute(field_ir.attribute, "byte_order")
- if byte_order_attr:
- byte_order = byte_order_attr.string_constant.text
- else:
- byte_order = ""
- field_alignment, field_offset = _alignment_of_location(field_ir.location)
- return _get_cpp_view_type_for_physical_type(
- field_ir.type, field_size, byte_order, ir,
- _offset_storage_adapter(buffer_type, field_alignment, field_offset),
- parent_addressable_unit, validator)
-
-
-def _generate_structure_field_methods(enclosing_type_name, field_ir, ir,
- parent_addressable_unit):
- if ir_util.field_is_virtual(field_ir):
- return _generate_structure_virtual_field_methods(
- enclosing_type_name, field_ir, ir)
- else:
- return _generate_structure_physical_field_methods(
- enclosing_type_name, field_ir, ir, parent_addressable_unit)
-
-
-def _generate_custom_validator_expression_for(field_ir, ir):
- """Returns a validator expression for the given field, or None."""
- requires_attr = ir_util.get_attribute(field_ir.attribute, "requires")
- if requires_attr:
- class _ValidatorFieldReader(object):
- """A "FieldReader" that translates the current field to `value`."""
-
- def render_existence(self, expression, subexpressions):
- del expression # Unused.
- assert False, "Shouldn't be here."
-
- def render_field(self, expression, ir, subexpressions):
- assert len(expression.field_reference.path) == 1
- assert (expression.field_reference.path[0].canonical_name ==
- field_ir.name.canonical_name)
- expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
- return "{}(emboss_reserved_local_value)".format(
- _maybe_type(expression_cpp_type))
-
- validation_body = _render_expression(requires_attr.expression, ir,
- _ValidatorFieldReader())
- return validation_body.rendered
- else:
+ """Returns the appropriate C++ integer type to hold min_val up to max_val."""
+ # The choice of int32_t, uint32_t, int64_t, then uint64_t is somewhat
+ # arbitrary here, and might not be perfectly ideal. I (bolms@) have chosen
+ # this set of types to a) minimize the number of casts that occur in
+ # arithmetic expressions, and b) favor 32-bit arithmetic, which is mostly
+ # "cheapest" on current (2018) systems. Signed integers are also preferred
+ # over unsigned so that the C++ compiler can take advantage of undefined
+ # overflow.
+ for size in (32, 64):
+ if min_val >= -(2 ** (size - 1)) and max_val <= 2 ** (size - 1) - 1:
+ return "::std::int{}_t".format(size)
+ elif min_val >= 0 and max_val <= 2**size - 1:
+ return "::std::uint{}_t".format(size)
return None
+def _cpp_integer_type_for_enum(max_bits, is_signed):
+ """Returns the appropriate C++ integer type to hold an enum."""
+ # This is used to determine the `X` in `enum class : X`.
+ #
+ # Unlike _cpp_integer_type_for_range, the type chosen here is used for actual
+ # storage. Further, sizes smaller than 64 are explicitly chosen by a human
+ # author, so take the smallest size that can hold the given number of bits.
+ #
+ # Technically, the C++ standard allows some of these sizes of integer to not
+ # exist, and other sizes (say, int24_t) might exist, but in practice this set
+ # is almost always available. If you're compiling for some exotic DSP that
+ # uses unusual int sizes, email emboss-dev@google.com.
+ for size in (8, 16, 32, 64):
+ if max_bits <= size:
+ return "::std::{}int{}_t".format("" if is_signed else "u", size)
+ assert False, f"Invalid value {max_bits} for maximum_bits"
+
+
+def _render_builtin_operation(expression, ir, field_reader, subexpressions):
+ """Renders a built-in operation (+, -, &&, etc.) into C++ code."""
+ assert expression.function.function not in (
+ ir_data.FunctionMapping.UPPER_BOUND,
+ ir_data.FunctionMapping.LOWER_BOUND,
+ ), "UPPER_BOUND and LOWER_BOUND should be constant."
+ if expression.function.function == ir_data.FunctionMapping.PRESENCE:
+ return field_reader.render_existence(
+ expression.function.args[0], subexpressions
+ )
+ args = expression.function.args
+ rendered_args = [
+ _render_expression(arg, ir, field_reader, subexpressions).rendered
+ for arg in args
+ ]
+ minimum_integers = []
+ maximum_integers = []
+ enum_types = set()
+ have_boolean_types = False
+ for subexpression in [expression] + list(args):
+ if subexpression.type.WhichOneof("type") == "integer":
+ minimum_integers.append(int(subexpression.type.integer.minimum_value))
+ maximum_integers.append(int(subexpression.type.integer.maximum_value))
+ elif subexpression.type.WhichOneof("type") == "enumeration":
+ enum_types.add(_cpp_basic_type_for_expression(subexpression, ir))
+ elif subexpression.type.WhichOneof("type") == "boolean":
+ have_boolean_types = True
+ # At present, all Emboss functions other than `$has` take and return one of
+ # the following:
+ #
+ # integers
+ # integers and booleans
+ # a single enum type
+ # a single enum type and booleans
+ # booleans
+ #
+ # Really, the intermediate type is only necessary for integers, but it
+ # simplifies the C++ somewhat if the appropriate enum/boolean type is provided
+ # as "IntermediateT" -- it means that, e.g., the choice ("?:") operator does
+ # not have to have two versions, one of which casts (some of) its arguments to
+ # IntermediateT, and one of which does not.
+ #
+ # This is not a particularly robust scheme, but it works for all of the Emboss
+ # functions I (bolms@) have written and am considering (division, modulus,
+ # exponentiation, logical negation, bit shifts, bitwise and/or/xor, $min,
+ # $floor, $ceil, $has).
+ if minimum_integers and not enum_types:
+ intermediate_type = _cpp_integer_type_for_range(
+ min(minimum_integers), max(maximum_integers)
+ )
+ elif len(enum_types) == 1 and not minimum_integers:
+ intermediate_type = list(enum_types)[0]
+ else:
+ assert have_boolean_types
+ assert not enum_types
+ assert not minimum_integers
+ intermediate_type = "bool"
+ arg_types = [_cpp_basic_type_for_expression(arg, ir) for arg in args]
+ result_type = _cpp_basic_type_for_expression(expression, ir)
+ function_variant = "</**/{}, {}, {}>".format(
+ intermediate_type, result_type, ", ".join(arg_types)
+ )
+ return "::emboss::support::{}{}({})".format(
+ _builtin_function_name(expression.function.function),
+ function_variant,
+ ", ".join(rendered_args),
+ )
+
+
+class _FieldRenderer(object):
+ """Base class for rendering field reads."""
+
+ def render_field_read_with_context(self, expression, ir, prefix, subexpressions):
+ field = prefix + _render_variable(
+ ir_util.hashable_form_of_field_reference(expression.field_reference)
+ )
+ if subexpressions is None:
+ field_expression = field
+ else:
+ field_expression = subexpressions.add(field)
+ expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
+ return (
+ "({0}.Ok()"
+ " ? {1}(static_cast</**/{2}>({0}.UncheckedRead()))"
+ " : {1}())".format(
+ field_expression, _maybe_type(expression_cpp_type), expression_cpp_type
+ )
+ )
+
+ def render_existence_with_context(self, expression, prefix, subexpressions):
+ return "{1}{0}".format(
+ _render_variable(
+ ir_util.hashable_form_of_field_reference(expression.field_reference),
+ "has_",
+ ),
+ prefix,
+ )
+
+
+class _DirectFieldRenderer(_FieldRenderer):
+ """Renderer for fields read from inside a structure's View type."""
+
+ def render_field(self, expression, ir, subexpressions):
+ return self.render_field_read_with_context(expression, ir, "", subexpressions)
+
+ def render_existence(self, expression, subexpressions):
+ return self.render_existence_with_context(expression, "", subexpressions)
+
+
+class _VirtualViewFieldRenderer(_FieldRenderer):
+ """Renderer for field reads from inside a virtual field's View."""
+
+ def render_existence(self, expression, subexpressions):
+ return self.render_existence_with_context(expression, "view_.", subexpressions)
+
+ def render_field(self, expression, ir, subexpressions):
+ return self.render_field_read_with_context(
+ expression, ir, "view_.", subexpressions
+ )
+
+
+class _SubexpressionStore(object):
+ """Holder for subexpressions to be assigned to local variables."""
+
+ def __init__(self, prefix):
+ self._prefix = prefix
+ self._subexpr_to_name = {}
+ self._index_to_subexpr = []
+
+ def add(self, subexpr):
+ if subexpr not in self._subexpr_to_name:
+ self._index_to_subexpr.append(subexpr)
+ self._subexpr_to_name[subexpr] = self._prefix + str(
+ len(self._index_to_subexpr)
+ )
+ return self._subexpr_to_name[subexpr]
+
+ def subexprs(self):
+ return [
+ (self._subexpr_to_name[subexpr], subexpr)
+ for subexpr in self._index_to_subexpr
+ ]
+
+
+_ExpressionResult = collections.namedtuple(
+ "ExpressionResult", ["rendered", "is_constant"]
+)
+
+
+def _render_expression(expression, ir, field_reader=None, subexpressions=None):
+ """Renders an expression into C++ code.
+
+ Arguments:
+ expression: The expression to render.
+ ir: The IR in which to look up references.
+ field_reader: An object with render_existence and render_field methods
+ appropriate for the C++ context of the expression.
+ subexpressions: A _SubexpressionStore in which to put subexpressions, or
+ None if subexpressions should be inline.
+
+ Returns:
+ A tuple of (rendered_text, is_constant), where rendered_text is C++ code
+ that can be emitted, and is_constant is True if the expression is a
+ compile-time constant suitable for use in a C++11 constexpr context,
+ otherwise False.
+ """
+ if field_reader is None:
+ field_reader = _DirectFieldRenderer()
+
+ # If the expression is constant, there are no guarantees that subexpressions
+ # will fit into C++ types, or that operator arguments and return types can fit
+ # in the same type: expressions like `-0x8000_0000_0000_0000` and
+ # `0x1_0000_0000_0000_0000 - 1` can appear.
+ if expression.type.WhichOneof("type") == "integer":
+ if expression.type.integer.modulus == "infinity":
+ return _ExpressionResult(
+ _render_integer_for_expression(
+ int(expression.type.integer.modular_value)
+ ),
+ True,
+ )
+ elif expression.type.WhichOneof("type") == "boolean":
+ if expression.type.boolean.HasField("value"):
+ if expression.type.boolean.value:
+ return _ExpressionResult(_maybe_type("bool") + "(true)", True)
+ else:
+ return _ExpressionResult(_maybe_type("bool") + "(false)", True)
+ elif expression.type.WhichOneof("type") == "enumeration":
+ if expression.type.enumeration.HasField("value"):
+ return _ExpressionResult(
+ _render_enum_value(expression.type.enumeration, ir), True
+ )
+ else:
+ # There shouldn't be any "opaque" type expressions here.
+ assert False, "Unhandled expression type {}".format(
+ expression.type.WhichOneof("type")
+ )
+
+ result = None
+ # Otherwise, render the operation.
+ if expression.WhichOneof("expression") == "function":
+ result = _render_builtin_operation(expression, ir, field_reader, subexpressions)
+ elif expression.WhichOneof("expression") == "field_reference":
+ result = field_reader.render_field(expression, ir, subexpressions)
+ elif (
+ expression.WhichOneof("expression") == "builtin_reference"
+ and expression.builtin_reference.canonical_name.object_path[-1]
+ == "$logical_value"
+ ):
+ return _ExpressionResult(
+ _maybe_type("decltype(emboss_reserved_local_value)")
+ + "(emboss_reserved_local_value)",
+ False,
+ )
+
+ # Any of the constant expression types should have been handled in the
+ # previous section.
+ assert result is not None, "Unable to render expression {}".format(str(expression))
+
+ if subexpressions is None:
+ return _ExpressionResult(result, False)
+ else:
+ return _ExpressionResult(subexpressions.add(result), False)
+
+
+def _render_existence_test(field, ir, subexpressions=None):
+ return _render_expression(field.existence_condition, ir, subexpressions)
+
+
+def _alignment_of_location(location):
+ constraints = location.start.type.integer
+ if constraints.modulus == "infinity":
+ # The C++ templates use 0 as a sentinel value meaning infinity for
+ # alignment.
+ return 0, constraints.modular_value
+ else:
+ return constraints.modulus, constraints.modular_value
+
+
+def _get_cpp_type_reader_of_field(
+ field_ir, ir, buffer_type, validator, parent_addressable_unit
+):
+ """Returns the C++ view type for a field."""
+ field_size = None
+ if field_ir.type.HasField("size_in_bits"):
+ field_size = ir_util.constant_value(field_ir.type.size_in_bits)
+ assert field_size is not None
+ elif ir_util.is_constant(field_ir.location.size):
+ # TODO(bolms): Normalize the IR so that this clause is unnecessary.
+ field_size = (
+ ir_util.constant_value(field_ir.location.size) * parent_addressable_unit
+ )
+ byte_order_attr = ir_util.get_attribute(field_ir.attribute, "byte_order")
+ if byte_order_attr:
+ byte_order = byte_order_attr.string_constant.text
+ else:
+ byte_order = ""
+ field_alignment, field_offset = _alignment_of_location(field_ir.location)
+ return _get_cpp_view_type_for_physical_type(
+ field_ir.type,
+ field_size,
+ byte_order,
+ ir,
+ _offset_storage_adapter(buffer_type, field_alignment, field_offset),
+ parent_addressable_unit,
+ validator,
+ )
+
+
+def _generate_structure_field_methods(
+ enclosing_type_name, field_ir, ir, parent_addressable_unit
+):
+ if ir_util.field_is_virtual(field_ir):
+ return _generate_structure_virtual_field_methods(
+ enclosing_type_name, field_ir, ir
+ )
+ else:
+ return _generate_structure_physical_field_methods(
+ enclosing_type_name, field_ir, ir, parent_addressable_unit
+ )
+
+
+def _generate_custom_validator_expression_for(field_ir, ir):
+ """Returns a validator expression for the given field, or None."""
+ requires_attr = ir_util.get_attribute(field_ir.attribute, "requires")
+ if requires_attr:
+
+ class _ValidatorFieldReader(object):
+ """A "FieldReader" that translates the current field to `value`."""
+
+ def render_existence(self, expression, subexpressions):
+ del expression # Unused.
+ assert False, "Shouldn't be here."
+
+ def render_field(self, expression, ir, subexpressions):
+ assert len(expression.field_reference.path) == 1
+ assert (
+ expression.field_reference.path[0].canonical_name
+ == field_ir.name.canonical_name
+ )
+ expression_cpp_type = _cpp_basic_type_for_expression(expression, ir)
+ return "{}(emboss_reserved_local_value)".format(
+ _maybe_type(expression_cpp_type)
+ )
+
+ validation_body = _render_expression(
+ requires_attr.expression, ir, _ValidatorFieldReader()
+ )
+ return validation_body.rendered
+ else:
+ return None
+
+
def _generate_validator_expression_for(field_ir, ir):
- """Returns a validator expression for the given field."""
- result = _generate_custom_validator_expression_for(field_ir, ir)
- if result is None:
- return "::emboss::support::Maybe<bool>(true)"
- return result
+ """Returns a validator expression for the given field."""
+ result = _generate_custom_validator_expression_for(field_ir, ir)
+ if result is None:
+ return "::emboss::support::Maybe<bool>(true)"
+ return result
-def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir,
- ir):
- """Generates C++ code for methods for a single virtual field.
+def _generate_structure_virtual_field_methods(enclosing_type_name, field_ir, ir):
+ """Generates C++ code for methods for a single virtual field.
- Arguments:
- enclosing_type_name: The text name of the enclosing type.
- field_ir: The IR for the field to generate methods for.
- ir: The full IR for the module.
+ Arguments:
+ enclosing_type_name: The text name of the enclosing type.
+ field_ir: The IR for the field to generate methods for.
+ ir: The full IR for the module.
- Returns:
- A tuple of ("", declarations, definitions). The declarations can be
- inserted into the class definition for the enclosing type's View. Any
- definitions should be placed after the class definition. These are
- separated to satisfy C++'s declaration-before-use requirements.
- """
- if field_ir.write_method.WhichOneof("method") == "alias":
- return _generate_field_indirection(field_ir, enclosing_type_name, ir)
+ Returns:
+ A tuple of ("", declarations, definitions). The declarations can be
+ inserted into the class definition for the enclosing type's View. Any
+ definitions should be placed after the class definition. These are
+ separated to satisfy C++'s declaration-before-use requirements.
+ """
+ if field_ir.write_method.WhichOneof("method") == "alias":
+ return _generate_field_indirection(field_ir, enclosing_type_name, ir)
- read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
- read_value = _render_expression(
- field_ir.read_transform, ir,
- field_reader=_VirtualViewFieldRenderer(),
- subexpressions=read_subexpressions)
- field_exists = _render_existence_test(field_ir, ir)
- logical_type = _cpp_basic_type_for_expression(field_ir.read_transform, ir)
+ read_subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
+ read_value = _render_expression(
+ field_ir.read_transform,
+ ir,
+ field_reader=_VirtualViewFieldRenderer(),
+ subexpressions=read_subexpressions,
+ )
+ field_exists = _render_existence_test(field_ir, ir)
+ logical_type = _cpp_basic_type_for_expression(field_ir.read_transform, ir)
- if read_value.is_constant and field_exists.is_constant:
- assert not read_subexpressions.subexprs()
- declaration_template = (
- _TEMPLATES.structure_single_const_virtual_field_method_declarations)
- definition_template = (
- _TEMPLATES.structure_single_const_virtual_field_method_definitions)
- else:
- declaration_template = (
- _TEMPLATES.structure_single_virtual_field_method_declarations)
- definition_template = (
- _TEMPLATES.structure_single_virtual_field_method_definitions)
+ if read_value.is_constant and field_exists.is_constant:
+ assert not read_subexpressions.subexprs()
+ declaration_template = (
+ _TEMPLATES.structure_single_const_virtual_field_method_declarations
+ )
+ definition_template = (
+ _TEMPLATES.structure_single_const_virtual_field_method_definitions
+ )
+ else:
+ declaration_template = (
+ _TEMPLATES.structure_single_virtual_field_method_declarations
+ )
+ definition_template = (
+ _TEMPLATES.structure_single_virtual_field_method_definitions
+ )
- if field_ir.write_method.WhichOneof("method") == "transform":
- destination = _render_variable(
- ir_util.hashable_form_of_field_reference(
- field_ir.write_method.transform.destination))
- transform = _render_expression(
- field_ir.write_method.transform.function_body, ir,
- field_reader=_VirtualViewFieldRenderer()).rendered
- write_methods = code_template.format_template(
- _TEMPLATES.structure_single_virtual_field_write_methods,
+ if field_ir.write_method.WhichOneof("method") == "transform":
+ destination = _render_variable(
+ ir_util.hashable_form_of_field_reference(
+ field_ir.write_method.transform.destination
+ )
+ )
+ transform = _render_expression(
+ field_ir.write_method.transform.function_body,
+ ir,
+ field_reader=_VirtualViewFieldRenderer(),
+ ).rendered
+ write_methods = code_template.format_template(
+ _TEMPLATES.structure_single_virtual_field_write_methods,
+ logical_type=logical_type,
+ destination=destination,
+ transform=transform,
+ )
+ else:
+ write_methods = ""
+
+ name = field_ir.name.canonical_name.object_path[-1]
+ if name.startswith("$"):
+ name = _cpp_field_name(field_ir.name.name.text)
+ virtual_view_type_name = "EmbossReservedDollarVirtual{}View".format(name)
+ else:
+ virtual_view_type_name = "EmbossReservedVirtual{}View".format(
+ name_conversion.snake_to_camel(name)
+ )
+ assert logical_type, "Could not find appropriate C++ type for {}".format(
+ field_ir.read_transform
+ )
+ if field_ir.read_transform.type.WhichOneof("type") == "integer":
+ write_to_text_stream_function = "WriteIntegerViewToTextStream"
+ elif field_ir.read_transform.type.WhichOneof("type") == "boolean":
+ write_to_text_stream_function = "WriteBooleanViewToTextStream"
+ elif field_ir.read_transform.type.WhichOneof("type") == "enumeration":
+ write_to_text_stream_function = "WriteEnumViewToTextStream"
+ else:
+ assert False, "Unexpected read-only virtual field type {}".format(
+ field_ir.read_transform.type.WhichOneof("type")
+ )
+
+ value_is_ok = _generate_validator_expression_for(field_ir, ir)
+ declaration = code_template.format_template(
+ declaration_template,
+ visibility=_visibility_for_field(field_ir),
+ name=name,
+ virtual_view_type_name=virtual_view_type_name,
logical_type=logical_type,
- destination=destination,
- transform=transform)
- else:
- write_methods = ""
-
- name = field_ir.name.canonical_name.object_path[-1]
- if name.startswith("$"):
- name = _cpp_field_name(field_ir.name.name.text)
- virtual_view_type_name = "EmbossReservedDollarVirtual{}View".format(name)
- else:
- virtual_view_type_name = "EmbossReservedVirtual{}View".format(
- name_conversion.snake_to_camel(name))
- assert logical_type, "Could not find appropriate C++ type for {}".format(
- field_ir.read_transform)
- if field_ir.read_transform.type.WhichOneof("type") == "integer":
- write_to_text_stream_function = "WriteIntegerViewToTextStream"
- elif field_ir.read_transform.type.WhichOneof("type") == "boolean":
- write_to_text_stream_function = "WriteBooleanViewToTextStream"
- elif field_ir.read_transform.type.WhichOneof("type") == "enumeration":
- write_to_text_stream_function = "WriteEnumViewToTextStream"
- else:
- assert False, "Unexpected read-only virtual field type {}".format(
- field_ir.read_transform.type.WhichOneof("type"))
-
- value_is_ok = _generate_validator_expression_for(field_ir, ir)
- declaration = code_template.format_template(
- declaration_template,
- visibility=_visibility_for_field(field_ir),
- name=name,
- virtual_view_type_name=virtual_view_type_name,
- logical_type=logical_type,
- read_subexpressions="".join(
- [" const auto {} = {};\n".format(subexpr_name, subexpr)
- for subexpr_name, subexpr in read_subexpressions.subexprs()]
- ),
- read_value=read_value.rendered,
- write_to_text_stream_function=write_to_text_stream_function,
- parent_type=enclosing_type_name,
- write_methods=write_methods,
- value_is_ok=value_is_ok)
- definition = code_template.format_template(
- definition_template,
- name=name,
- virtual_view_type_name=virtual_view_type_name,
- logical_type=logical_type,
- read_value=read_value.rendered,
- parent_type=enclosing_type_name,
- field_exists=field_exists.rendered)
- return "", declaration, definition
+ read_subexpressions="".join(
+ [
+ " const auto {} = {};\n".format(subexpr_name, subexpr)
+ for subexpr_name, subexpr in read_subexpressions.subexprs()
+ ]
+ ),
+ read_value=read_value.rendered,
+ write_to_text_stream_function=write_to_text_stream_function,
+ parent_type=enclosing_type_name,
+ write_methods=write_methods,
+ value_is_ok=value_is_ok,
+ )
+ definition = code_template.format_template(
+ definition_template,
+ name=name,
+ virtual_view_type_name=virtual_view_type_name,
+ logical_type=logical_type,
+ read_value=read_value.rendered,
+ parent_type=enclosing_type_name,
+ field_exists=field_exists.rendered,
+ )
+ return "", declaration, definition
def _generate_validator_type_for(enclosing_type_name, field_ir, ir):
- """Returns a validator type name and definition for the given field."""
- result_expression = _generate_custom_validator_expression_for(field_ir, ir)
- if result_expression is None:
- return "::emboss::support::AllValuesAreOk", ""
+ """Returns a validator type name and definition for the given field."""
+ result_expression = _generate_custom_validator_expression_for(field_ir, ir)
+ if result_expression is None:
+ return "::emboss::support::AllValuesAreOk", ""
- field_name = field_ir.name.canonical_name.object_path[-1]
- validator_type_name = "EmbossReservedValidatorFor{}".format(
- name_conversion.snake_to_camel(field_name))
- qualified_validator_type_name = "{}::{}".format(enclosing_type_name,
- validator_type_name)
+ field_name = field_ir.name.canonical_name.object_path[-1]
+ validator_type_name = "EmbossReservedValidatorFor{}".format(
+ name_conversion.snake_to_camel(field_name)
+ )
+ qualified_validator_type_name = "{}::{}".format(
+ enclosing_type_name, validator_type_name
+ )
- validator_declaration = code_template.format_template(
- _TEMPLATES.structure_field_validator,
- name=validator_type_name,
- expression=result_expression,
- )
- validator_declaration = _wrap_in_namespace(validator_declaration,
- [enclosing_type_name])
- return qualified_validator_type_name, validator_declaration
+ validator_declaration = code_template.format_template(
+ _TEMPLATES.structure_field_validator,
+ name=validator_type_name,
+ expression=result_expression,
+ )
+ validator_declaration = _wrap_in_namespace(
+ validator_declaration, [enclosing_type_name]
+ )
+ return qualified_validator_type_name, validator_declaration
-def _generate_structure_physical_field_methods(enclosing_type_name, field_ir,
- ir, parent_addressable_unit):
- """Generates C++ code for methods for a single physical field.
+def _generate_structure_physical_field_methods(
+ enclosing_type_name, field_ir, ir, parent_addressable_unit
+):
+ """Generates C++ code for methods for a single physical field.
- Arguments:
- enclosing_type_name: The text name of the enclosing type.
- field_ir: The IR for the field to generate methods for.
- ir: The full IR for the module.
- parent_addressable_unit: The addressable unit (BIT or BYTE) of the enclosing
- structure.
+ Arguments:
+ enclosing_type_name: The text name of the enclosing type.
+ field_ir: The IR for the field to generate methods for.
+ ir: The full IR for the module.
+ parent_addressable_unit: The addressable unit (BIT or BYTE) of the enclosing
+ structure.
- Returns:
- A tuple of (declarations, definitions). The declarations can be inserted
- into the class definition for the enclosing type's View. Any definitions
- should be placed after the class definition. These are separated to satisfy
- C++'s declaration-before-use requirements.
- """
- validator_type, validator_declaration = _generate_validator_type_for(
- enclosing_type_name, field_ir, ir)
+ Returns:
+ A tuple of (declarations, definitions). The declarations can be inserted
+ into the class definition for the enclosing type's View. Any definitions
+ should be placed after the class definition. These are separated to satisfy
+ C++'s declaration-before-use requirements.
+ """
+ validator_type, validator_declaration = _generate_validator_type_for(
+ enclosing_type_name, field_ir, ir
+ )
- type_reader, unused_parameter_types, parameter_expressions = (
- _get_cpp_type_reader_of_field(field_ir, ir, "Storage", validator_type,
- parent_addressable_unit))
+ type_reader, unused_parameter_types, parameter_expressions = (
+ _get_cpp_type_reader_of_field(
+ field_ir, ir, "Storage", validator_type, parent_addressable_unit
+ )
+ )
- field_name = field_ir.name.canonical_name.object_path[-1]
+ field_name = field_ir.name.canonical_name.object_path[-1]
- subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
- parameter_values = []
- parameters_known = []
- for parameter in parameter_expressions:
- parameter_cpp_expr = _render_expression(
- parameter, ir, subexpressions=subexpressions)
- parameter_values.append(
- "{}.ValueOrDefault(), ".format(parameter_cpp_expr.rendered))
- parameters_known.append(
- "{}.Known() && ".format(parameter_cpp_expr.rendered))
- parameter_subexpressions = "".join(
- [" const auto {} = {};\n".format(name, subexpr)
- for name, subexpr in subexpressions.subexprs()]
- )
+ subexpressions = _SubexpressionStore("emboss_reserved_local_subexpr_")
+ parameter_values = []
+ parameters_known = []
+ for parameter in parameter_expressions:
+ parameter_cpp_expr = _render_expression(
+ parameter, ir, subexpressions=subexpressions
+ )
+ parameter_values.append(
+ "{}.ValueOrDefault(), ".format(parameter_cpp_expr.rendered)
+ )
+ parameters_known.append("{}.Known() && ".format(parameter_cpp_expr.rendered))
+ parameter_subexpressions = "".join(
+ [
+ " const auto {} = {};\n".format(name, subexpr)
+ for name, subexpr in subexpressions.subexprs()
+ ]
+ )
- first_size_and_offset_subexpr = len(subexpressions.subexprs())
- offset = _render_expression(
- field_ir.location.start, ir, subexpressions=subexpressions).rendered
- size = _render_expression(
- field_ir.location.size, ir, subexpressions=subexpressions).rendered
- size_and_offset_subexpressions = "".join(
- [" const auto {} = {};\n".format(name, subexpr)
- for name, subexpr in subexpressions.subexprs()[
- first_size_and_offset_subexpr:]]
- )
+ first_size_and_offset_subexpr = len(subexpressions.subexprs())
+ offset = _render_expression(
+ field_ir.location.start, ir, subexpressions=subexpressions
+ ).rendered
+ size = _render_expression(
+ field_ir.location.size, ir, subexpressions=subexpressions
+ ).rendered
+ size_and_offset_subexpressions = "".join(
+ [
+ " const auto {} = {};\n".format(name, subexpr)
+ for name, subexpr in subexpressions.subexprs()[
+ first_size_and_offset_subexpr:
+ ]
+ ]
+ )
- field_alignment, field_offset = _alignment_of_location(field_ir.location)
- declaration = code_template.format_template(
- _TEMPLATES.structure_single_field_method_declarations,
- type_reader=type_reader,
- visibility=_visibility_for_field(field_ir),
- name=field_name)
- definition = code_template.format_template(
- _TEMPLATES.structure_single_field_method_definitions,
- parent_type=enclosing_type_name,
- name=field_name,
- type_reader=type_reader,
- offset=offset,
- size=size,
- size_and_offset_subexpressions=size_and_offset_subexpressions,
- field_exists=_render_existence_test(field_ir, ir).rendered,
- alignment=field_alignment,
- parameters_known="".join(parameters_known),
- parameter_values="".join(parameter_values),
- parameter_subexpressions=parameter_subexpressions,
- static_offset=field_offset)
- return validator_declaration, declaration, definition
+ field_alignment, field_offset = _alignment_of_location(field_ir.location)
+ declaration = code_template.format_template(
+ _TEMPLATES.structure_single_field_method_declarations,
+ type_reader=type_reader,
+ visibility=_visibility_for_field(field_ir),
+ name=field_name,
+ )
+ definition = code_template.format_template(
+ _TEMPLATES.structure_single_field_method_definitions,
+ parent_type=enclosing_type_name,
+ name=field_name,
+ type_reader=type_reader,
+ offset=offset,
+ size=size,
+ size_and_offset_subexpressions=size_and_offset_subexpressions,
+ field_exists=_render_existence_test(field_ir, ir).rendered,
+ alignment=field_alignment,
+ parameters_known="".join(parameters_known),
+ parameter_values="".join(parameter_values),
+ parameter_subexpressions=parameter_subexpressions,
+ static_offset=field_offset,
+ )
+ return validator_declaration, declaration, definition
def _render_size_method(fields, ir):
- """Renders the Size methods of a struct or bits, using the correct templates.
+ """Renders the Size methods of a struct or bits, using the correct templates.
- Arguments:
- fields: The list of fields in the struct or bits. This is used to find the
- $size_in_bits or $size_in_bytes virtual field.
- ir: The IR to which fields belong.
+ Arguments:
+ fields: The list of fields in the struct or bits. This is used to find the
+ $size_in_bits or $size_in_bytes virtual field.
+ ir: The IR to which fields belong.
- Returns:
- A string representation of the Size methods, suitable for inclusion in an
- Emboss View class.
- """
- # The SizeInBytes(), SizeInBits(), and SizeIsKnown() methods just forward to
- # the generated IntrinsicSizeIn$_units_$() method, which returns a virtual
- # field with Read() and Ok() methods.
- #
- # TODO(bolms): Remove these shims, rename IntrinsicSizeIn$_units_$ to
- # SizeIn$_units_$, and update all callers to the new API.
- for field in fields:
- if field.name.name.text in ("$size_in_bits", "$size_in_bytes"):
- # If the read_transform and existence_condition are constant, then the
- # size is constexpr.
- if (_render_expression(field.read_transform, ir).is_constant and
- _render_expression(field.existence_condition, ir).is_constant):
- template = _TEMPLATES.constant_structure_size_method
- else:
- template = _TEMPLATES.runtime_structure_size_method
- return code_template.format_template(
- template,
- units="Bits" if field.name.name.text == "$size_in_bits" else "Bytes")
- assert False, "Expected a $size_in_bits or $size_in_bytes field."
+ Returns:
+ A string representation of the Size methods, suitable for inclusion in an
+ Emboss View class.
+ """
+ # The SizeInBytes(), SizeInBits(), and SizeIsKnown() methods just forward to
+ # the generated IntrinsicSizeIn$_units_$() method, which returns a virtual
+ # field with Read() and Ok() methods.
+ #
+ # TODO(bolms): Remove these shims, rename IntrinsicSizeIn$_units_$ to
+ # SizeIn$_units_$, and update all callers to the new API.
+ for field in fields:
+ if field.name.name.text in ("$size_in_bits", "$size_in_bytes"):
+ # If the read_transform and existence_condition are constant, then the
+ # size is constexpr.
+ if (
+ _render_expression(field.read_transform, ir).is_constant
+ and _render_expression(field.existence_condition, ir).is_constant
+ ):
+ template = _TEMPLATES.constant_structure_size_method
+ else:
+ template = _TEMPLATES.runtime_structure_size_method
+ return code_template.format_template(
+ template,
+ units="Bits" if field.name.name.text == "$size_in_bits" else "Bytes",
+ )
+ assert False, "Expected a $size_in_bits or $size_in_bytes field."
def _visibility_for_field(field_ir):
- """Returns the C++ visibility for field_ir within its parent view."""
- # Generally, the Google style guide for hand-written C++ forbids having
- # multiple public: and private: sections, but trying to conform to that bit of
- # the style guide would make this file significantly more complex.
- #
- # Alias fields are generated as simple methods that forward directly to the
- # aliased field's method:
- #
- # auto alias() const -> decltype(parent().child().aliased_subchild()) {
- # return parent().child().aliased_subchild();
- # }
- #
- # Figuring out the return type of `parent().child().aliased_subchild()` is
- # quite complex, since there are several levels of template indirection
- # involved. It is much easier to just leave it up to the C++ compiler.
- #
- # Unfortunately, the C++ compiler will complain if `parent()` is not declared
- # before `alias()`. If the `parent` field happens to be anonymous, the Google
- # style guide would put `parent()`'s declaration after `alias()`'s
- # declaration, which causes the C++ compiler to complain that `parent` is
- # unknown.
- #
- # The easy fix to this is just to declare `parent()` before `alias()`, and
- # explicitly mark `parent()` as `private` and `alias()` as `public`.
- #
- # Perhaps surprisingly, this limitation does not apply when `parent()`'s type
- # is not yet complete at the point where `alias()` is declared; I believe this
- # is because both `parent()` and `alias()` exist in a templated `class`, and
- # by the time `parent().child().aliased_subchild()` is actually resolved, the
- # compiler is instantiating the class and has the full definitions of all the
- # other classes available.
- if field_ir.name.is_anonymous:
- return "private"
- else:
- return "public"
+ """Returns the C++ visibility for field_ir within its parent view."""
+ # Generally, the Google style guide for hand-written C++ forbids having
+ # multiple public: and private: sections, but trying to conform to that bit of
+ # the style guide would make this file significantly more complex.
+ #
+ # Alias fields are generated as simple methods that forward directly to the
+ # aliased field's method:
+ #
+ # auto alias() const -> decltype(parent().child().aliased_subchild()) {
+ # return parent().child().aliased_subchild();
+ # }
+ #
+ # Figuring out the return type of `parent().child().aliased_subchild()` is
+ # quite complex, since there are several levels of template indirection
+ # involved. It is much easier to just leave it up to the C++ compiler.
+ #
+ # Unfortunately, the C++ compiler will complain if `parent()` is not declared
+ # before `alias()`. If the `parent` field happens to be anonymous, the Google
+ # style guide would put `parent()`'s declaration after `alias()`'s
+ # declaration, which causes the C++ compiler to complain that `parent` is
+ # unknown.
+ #
+ # The easy fix to this is just to declare `parent()` before `alias()`, and
+ # explicitly mark `parent()` as `private` and `alias()` as `public`.
+ #
+ # Perhaps surprisingly, this limitation does not apply when `parent()`'s type
+ # is not yet complete at the point where `alias()` is declared; I believe this
+ # is because both `parent()` and `alias()` exist in a templated `class`, and
+ # by the time `parent().child().aliased_subchild()` is actually resolved, the
+ # compiler is instantiating the class and has the full definitions of all the
+ # other classes available.
+ if field_ir.name.is_anonymous:
+ return "private"
+ else:
+ return "public"
def _generate_field_indirection(field_ir, parent_type_name, ir):
- """Renders a method which forwards to a field's view."""
- rendered_aliased_field = _render_variable(
- ir_util.hashable_form_of_field_reference(field_ir.write_method.alias))
- declaration = code_template.format_template(
- _TEMPLATES.structure_single_field_indirect_method_declarations,
- aliased_field=rendered_aliased_field,
- visibility=_visibility_for_field(field_ir),
- parent_type=parent_type_name,
- name=field_ir.name.name.text)
- definition = code_template.format_template(
- _TEMPLATES.struct_single_field_indirect_method_definitions,
- parent_type=parent_type_name,
- name=field_ir.name.name.text,
- aliased_field=rendered_aliased_field,
- field_exists=_render_existence_test(field_ir, ir).rendered)
- return "", declaration, definition
+ """Renders a method which forwards to a field's view."""
+ rendered_aliased_field = _render_variable(
+ ir_util.hashable_form_of_field_reference(field_ir.write_method.alias)
+ )
+ declaration = code_template.format_template(
+ _TEMPLATES.structure_single_field_indirect_method_declarations,
+ aliased_field=rendered_aliased_field,
+ visibility=_visibility_for_field(field_ir),
+ parent_type=parent_type_name,
+ name=field_ir.name.name.text,
+ )
+ definition = code_template.format_template(
+ _TEMPLATES.struct_single_field_indirect_method_definitions,
+ parent_type=parent_type_name,
+ name=field_ir.name.name.text,
+ aliased_field=rendered_aliased_field,
+ field_exists=_render_existence_test(field_ir, ir).rendered,
+ )
+ return "", declaration, definition
def _generate_subtype_definitions(type_ir, ir, config: Config):
- """Generates C++ code for subtypes of type_ir."""
- subtype_bodies = []
- subtype_forward_declarations = []
- subtype_method_definitions = []
- type_name = type_ir.name.name.text
- for subtype in type_ir.subtype:
- inner_defs = _generate_type_definition(subtype, ir, config)
- subtype_forward_declaration, subtype_body, subtype_methods = inner_defs
- subtype_forward_declarations.append(subtype_forward_declaration)
- subtype_bodies.append(subtype_body)
- subtype_method_definitions.append(subtype_methods)
- wrapped_forward_declarations = _wrap_in_namespace(
- "\n".join(subtype_forward_declarations), [type_name])
- wrapped_bodies = _wrap_in_namespace("\n".join(subtype_bodies), [type_name])
- wrapped_method_definitions = _wrap_in_namespace(
- "\n".join(subtype_method_definitions), [type_name])
- return (wrapped_bodies, wrapped_forward_declarations,
- wrapped_method_definitions)
+ """Generates C++ code for subtypes of type_ir."""
+ subtype_bodies = []
+ subtype_forward_declarations = []
+ subtype_method_definitions = []
+ type_name = type_ir.name.name.text
+ for subtype in type_ir.subtype:
+ inner_defs = _generate_type_definition(subtype, ir, config)
+ subtype_forward_declaration, subtype_body, subtype_methods = inner_defs
+ subtype_forward_declarations.append(subtype_forward_declaration)
+ subtype_bodies.append(subtype_body)
+ subtype_method_definitions.append(subtype_methods)
+ wrapped_forward_declarations = _wrap_in_namespace(
+ "\n".join(subtype_forward_declarations), [type_name]
+ )
+ wrapped_bodies = _wrap_in_namespace("\n".join(subtype_bodies), [type_name])
+ wrapped_method_definitions = _wrap_in_namespace(
+ "\n".join(subtype_method_definitions), [type_name]
+ )
+ return (wrapped_bodies, wrapped_forward_declarations, wrapped_method_definitions)
def _cpp_field_name(name):
- """Returns the C++ name for the given field name."""
- if name.startswith("$"):
- dollar_field_names = {
- "$size_in_bits": "IntrinsicSizeInBits",
- "$size_in_bytes": "IntrinsicSizeInBytes",
- "$max_size_in_bits": "MaxSizeInBits",
- "$min_size_in_bits": "MinSizeInBits",
- "$max_size_in_bytes": "MaxSizeInBytes",
- "$min_size_in_bytes": "MinSizeInBytes",
- }
- return dollar_field_names[name]
- else:
- return name
+ """Returns the C++ name for the given field name."""
+ if name.startswith("$"):
+ dollar_field_names = {
+ "$size_in_bits": "IntrinsicSizeInBits",
+ "$size_in_bytes": "IntrinsicSizeInBytes",
+ "$max_size_in_bits": "MaxSizeInBits",
+ "$min_size_in_bits": "MinSizeInBits",
+ "$max_size_in_bytes": "MaxSizeInBytes",
+ "$min_size_in_bytes": "MinSizeInBytes",
+ }
+ return dollar_field_names[name]
+ else:
+ return name
def _generate_structure_definition(type_ir, ir, config: Config):
- """Generates C++ for an Emboss structure (struct or bits).
+ """Generates C++ for an Emboss structure (struct or bits).
- Arguments:
- type_ir: The IR for the struct definition.
- ir: The full IR; used for type lookups.
+ Arguments:
+ type_ir: The IR for the struct definition.
+ ir: The full IR; used for type lookups.
- Returns:
- A tuple of: (forward declaration for classes, class bodies, method bodies),
- suitable for insertion into the appropriate places in the generated header.
- """
- subtype_bodies, subtype_forward_declarations, subtype_method_definitions = (
- _generate_subtype_definitions(type_ir, ir, config))
- type_name = type_ir.name.name.text
- field_helper_type_definitions = []
- field_method_declarations = []
- field_method_definitions = []
- virtual_field_type_definitions = []
- decode_field_clauses = []
- write_field_clauses = []
- ok_method_clauses = []
- equals_method_clauses = []
- unchecked_equals_method_clauses = []
- enum_using_statements = []
- parameter_fields = []
- constructor_parameters = []
- forwarded_parameters = []
- parameter_initializers = []
- parameter_copy_initializers = []
- units = {1: "Bits", 8: "Bytes"}[type_ir.addressable_unit]
+ Returns:
+ A tuple of: (forward declaration for classes, class bodies, method bodies),
+ suitable for insertion into the appropriate places in the generated header.
+ """
+ subtype_bodies, subtype_forward_declarations, subtype_method_definitions = (
+ _generate_subtype_definitions(type_ir, ir, config)
+ )
+ type_name = type_ir.name.name.text
+ field_helper_type_definitions = []
+ field_method_declarations = []
+ field_method_definitions = []
+ virtual_field_type_definitions = []
+ decode_field_clauses = []
+ write_field_clauses = []
+ ok_method_clauses = []
+ equals_method_clauses = []
+ unchecked_equals_method_clauses = []
+ enum_using_statements = []
+ parameter_fields = []
+ constructor_parameters = []
+ forwarded_parameters = []
+ parameter_initializers = []
+ parameter_copy_initializers = []
+ units = {1: "Bits", 8: "Bytes"}[type_ir.addressable_unit]
- for subtype in type_ir.subtype:
- if subtype.HasField("enumeration"):
- enum_using_statements.append(
- code_template.format_template(
- _TEMPLATES.enum_using_statement,
- component=_get_fully_qualified_name(subtype.name.canonical_name,
- ir),
- name=_get_unqualified_name(subtype.name.canonical_name)))
+ for subtype in type_ir.subtype:
+ if subtype.HasField("enumeration"):
+ enum_using_statements.append(
+ code_template.format_template(
+ _TEMPLATES.enum_using_statement,
+ component=_get_fully_qualified_name(
+ subtype.name.canonical_name, ir
+ ),
+ name=_get_unqualified_name(subtype.name.canonical_name),
+ )
+ )
- # TODO(bolms): Reorder parameter fields to optimize packing in the view type.
- for parameter in type_ir.runtime_parameter:
- parameter_type = _cpp_basic_type_for_expression_type(parameter.type, ir)
- parameter_name = parameter.name.name.text
- parameter_fields.append("{} {}_;".format(parameter_type, parameter_name))
- constructor_parameters.append(
- "{} {}, ".format(parameter_type, parameter_name))
- forwarded_parameters.append("::std::forward</**/{}>({}),".format(
- parameter_type, parameter_name))
- parameter_initializers.append(", {0}_({0})".format(parameter_name))
- parameter_copy_initializers.append(
- ", {0}_(emboss_reserved_local_other.{0}_)".format(parameter_name))
+ # TODO(bolms): Reorder parameter fields to optimize packing in the view type.
+ for parameter in type_ir.runtime_parameter:
+ parameter_type = _cpp_basic_type_for_expression_type(parameter.type, ir)
+ parameter_name = parameter.name.name.text
+ parameter_fields.append("{} {}_;".format(parameter_type, parameter_name))
+ constructor_parameters.append("{} {}, ".format(parameter_type, parameter_name))
+ forwarded_parameters.append(
+ "::std::forward</**/{}>({}),".format(parameter_type, parameter_name)
+ )
+ parameter_initializers.append(", {0}_({0})".format(parameter_name))
+ parameter_copy_initializers.append(
+ ", {0}_(emboss_reserved_local_other.{0}_)".format(parameter_name)
+ )
- field_method_declarations.append(
- code_template.format_template(
- _TEMPLATES.structure_single_parameter_field_method_declarations,
- name=parameter_name,
- logical_type=parameter_type))
- # TODO(bolms): Should parameters appear in text format?
- equals_method_clauses.append(
- code_template.format_template(_TEMPLATES.equals_method_test,
- field=parameter_name + "()"))
- unchecked_equals_method_clauses.append(
- code_template.format_template(_TEMPLATES.unchecked_equals_method_test,
- field=parameter_name + "()"))
- if type_ir.runtime_parameter:
- flag_name = "parameters_initialized_"
- parameter_copy_initializers.append(
- ", {0}(emboss_reserved_local_other.{0})".format(flag_name))
- parameters_initialized_flag = "bool {} = false;".format(flag_name)
- initialize_parameters_initialized_true = ", {}(true)".format(flag_name)
- parameter_checks = ["if (!{}) return false;".format(flag_name)]
- else:
- parameters_initialized_flag = ""
- initialize_parameters_initialized_true = ""
- parameter_checks = [""]
+ field_method_declarations.append(
+ code_template.format_template(
+ _TEMPLATES.structure_single_parameter_field_method_declarations,
+ name=parameter_name,
+ logical_type=parameter_type,
+ )
+ )
+ # TODO(bolms): Should parameters appear in text format?
+ equals_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.equals_method_test, field=parameter_name + "()"
+ )
+ )
+ unchecked_equals_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.unchecked_equals_method_test, field=parameter_name + "()"
+ )
+ )
+ if type_ir.runtime_parameter:
+ flag_name = "parameters_initialized_"
+ parameter_copy_initializers.append(
+ ", {0}(emboss_reserved_local_other.{0})".format(flag_name)
+ )
+ parameters_initialized_flag = "bool {} = false;".format(flag_name)
+ initialize_parameters_initialized_true = ", {}(true)".format(flag_name)
+ parameter_checks = ["if (!{}) return false;".format(flag_name)]
+ else:
+ parameters_initialized_flag = ""
+ initialize_parameters_initialized_true = ""
+ parameter_checks = [""]
- for field_index in type_ir.structure.fields_in_dependency_order:
- field = type_ir.structure.field[field_index]
- helper_types, declaration, definition = (
- _generate_structure_field_methods(
- type_name, field, ir, type_ir.addressable_unit))
- field_helper_type_definitions.append(helper_types)
- field_method_definitions.append(definition)
- ok_method_clauses.append(
- code_template.format_template(
- _TEMPLATES.ok_method_test,
- field=_cpp_field_name(field.name.name.text) + "()"))
- if not ir_util.field_is_virtual(field):
- # Virtual fields do not participate in equality tests -- they are equal by
- # definition.
- equals_method_clauses.append(
- code_template.format_template(
- _TEMPLATES.equals_method_test, field=field.name.name.text + "()"))
- unchecked_equals_method_clauses.append(
- code_template.format_template(
- _TEMPLATES.unchecked_equals_method_test,
- field=field.name.name.text + "()"))
- field_method_declarations.append(declaration)
- if not field.name.is_anonymous and not ir_util.field_is_read_only(field):
- # As above, read-only fields cannot be decoded from text format.
- decode_field_clauses.append(
- code_template.format_template(
- _TEMPLATES.decode_field,
- field_name=field.name.canonical_name.object_path[-1]))
- text_output_attr = ir_util.get_attribute(field.attribute, "text_output")
- if not text_output_attr or text_output_attr.string_constant == "Emit":
- if ir_util.field_is_read_only(field):
- write_field_template = _TEMPLATES.write_read_only_field_to_text_stream
- else:
- write_field_template = _TEMPLATES.write_field_to_text_stream
- write_field_clauses.append(
- code_template.format_template(
- write_field_template,
- field_name=field.name.canonical_name.object_path[-1]))
+ for field_index in type_ir.structure.fields_in_dependency_order:
+ field = type_ir.structure.field[field_index]
+ helper_types, declaration, definition = _generate_structure_field_methods(
+ type_name, field, ir, type_ir.addressable_unit
+ )
+ field_helper_type_definitions.append(helper_types)
+ field_method_definitions.append(definition)
+ ok_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.ok_method_test,
+ field=_cpp_field_name(field.name.name.text) + "()",
+ )
+ )
+ if not ir_util.field_is_virtual(field):
+ # Virtual fields do not participate in equality tests -- they are equal by
+ # definition.
+ equals_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.equals_method_test, field=field.name.name.text + "()"
+ )
+ )
+ unchecked_equals_method_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.unchecked_equals_method_test,
+ field=field.name.name.text + "()",
+ )
+ )
+ field_method_declarations.append(declaration)
+ if not field.name.is_anonymous and not ir_util.field_is_read_only(field):
+ # As above, read-only fields cannot be decoded from text format.
+ decode_field_clauses.append(
+ code_template.format_template(
+ _TEMPLATES.decode_field,
+ field_name=field.name.canonical_name.object_path[-1],
+ )
+ )
+ text_output_attr = ir_util.get_attribute(field.attribute, "text_output")
+ if not text_output_attr or text_output_attr.string_constant == "Emit":
+ if ir_util.field_is_read_only(field):
+ write_field_template = _TEMPLATES.write_read_only_field_to_text_stream
+ else:
+ write_field_template = _TEMPLATES.write_field_to_text_stream
+ write_field_clauses.append(
+ code_template.format_template(
+ write_field_template,
+ field_name=field.name.canonical_name.object_path[-1],
+ )
+ )
- requires_attr = ir_util.get_attribute(type_ir.attribute, "requires")
- if requires_attr is not None:
- requires_clause = _render_expression(
- requires_attr.expression, ir, _DirectFieldRenderer()).rendered
- requires_check = (" if (!({}).ValueOr(false))\n"
- " return false;").format(requires_clause)
- else:
- requires_check = ""
+ requires_attr = ir_util.get_attribute(type_ir.attribute, "requires")
+ if requires_attr is not None:
+ requires_clause = _render_expression(
+ requires_attr.expression, ir, _DirectFieldRenderer()
+ ).rendered
+ requires_check = (
+ " if (!({}).ValueOr(false))\n" " return false;"
+ ).format(requires_clause)
+ else:
+ requires_check = ""
- if config.include_enum_traits:
- text_stream_methods = code_template.format_template(
- _TEMPLATES.struct_text_stream,
- decode_fields="\n".join(decode_field_clauses),
- write_fields="\n".join(write_field_clauses))
- else:
- text_stream_methods = ""
+ if config.include_enum_traits:
+ text_stream_methods = code_template.format_template(
+ _TEMPLATES.struct_text_stream,
+ decode_fields="\n".join(decode_field_clauses),
+ write_fields="\n".join(write_field_clauses),
+ )
+ else:
+ text_stream_methods = ""
-
- class_forward_declarations = code_template.format_template(
- _TEMPLATES.structure_view_declaration,
- name=type_name)
- class_bodies = code_template.format_template(
- _TEMPLATES.structure_view_class,
- name=type_ir.name.canonical_name.object_path[-1],
- size_method=_render_size_method(type_ir.structure.field, ir),
- field_method_declarations="".join(field_method_declarations),
- field_ok_checks="\n".join(ok_method_clauses),
- parameter_ok_checks="\n".join(parameter_checks),
- requires_check=requires_check,
- equals_method_body="\n".join(equals_method_clauses),
- unchecked_equals_method_body="\n".join(unchecked_equals_method_clauses),
- enum_usings="\n".join(enum_using_statements),
- text_stream_methods=text_stream_methods,
- parameter_fields="\n".join(parameter_fields),
- constructor_parameters="".join(constructor_parameters),
- forwarded_parameters="".join(forwarded_parameters),
- parameter_initializers="\n".join(parameter_initializers),
- parameter_copy_initializers="\n".join(parameter_copy_initializers),
- parameters_initialized_flag=parameters_initialized_flag,
- initialize_parameters_initialized_true=(
- initialize_parameters_initialized_true),
- units=units)
- method_definitions = "\n".join(field_method_definitions)
- early_virtual_field_types = "\n".join(virtual_field_type_definitions)
- all_field_helper_type_definitions = "\n".join(field_helper_type_definitions)
- return (early_virtual_field_types + subtype_forward_declarations +
- class_forward_declarations,
- all_field_helper_type_definitions + subtype_bodies + class_bodies,
- subtype_method_definitions + method_definitions)
+ class_forward_declarations = code_template.format_template(
+ _TEMPLATES.structure_view_declaration, name=type_name
+ )
+ class_bodies = code_template.format_template(
+ _TEMPLATES.structure_view_class,
+ name=type_ir.name.canonical_name.object_path[-1],
+ size_method=_render_size_method(type_ir.structure.field, ir),
+ field_method_declarations="".join(field_method_declarations),
+ field_ok_checks="\n".join(ok_method_clauses),
+ parameter_ok_checks="\n".join(parameter_checks),
+ requires_check=requires_check,
+ equals_method_body="\n".join(equals_method_clauses),
+ unchecked_equals_method_body="\n".join(unchecked_equals_method_clauses),
+ enum_usings="\n".join(enum_using_statements),
+ text_stream_methods=text_stream_methods,
+ parameter_fields="\n".join(parameter_fields),
+ constructor_parameters="".join(constructor_parameters),
+ forwarded_parameters="".join(forwarded_parameters),
+ parameter_initializers="\n".join(parameter_initializers),
+ parameter_copy_initializers="\n".join(parameter_copy_initializers),
+ parameters_initialized_flag=parameters_initialized_flag,
+ initialize_parameters_initialized_true=(initialize_parameters_initialized_true),
+ units=units,
+ )
+ method_definitions = "\n".join(field_method_definitions)
+ early_virtual_field_types = "\n".join(virtual_field_type_definitions)
+ all_field_helper_type_definitions = "\n".join(field_helper_type_definitions)
+ return (
+ early_virtual_field_types
+ + subtype_forward_declarations
+ + class_forward_declarations,
+ all_field_helper_type_definitions + subtype_bodies + class_bodies,
+ subtype_method_definitions + method_definitions,
+ )
def _split_enum_case_values_into_spans(enum_case_value):
- """Yields spans containing each enum case in an enum_case attribute value.
+ """Yields spans containing each enum case in an enum_case attribute value.
- Each span is of the form (start, end), which is the start and end position
- relative to the beginning of the enum_case_value string. To keep the grammar
- of this attribute simple, this only splits on delimiters and trims whitespace
- for each case.
+ Each span is of the form (start, end), which is the start and end position
+ relative to the beginning of the enum_case_value string. To keep the grammar
+ of this attribute simple, this only splits on delimiters and trims whitespace
+ for each case.
- Example: 'SHOUTY_CASE, kCamelCase' -> [(0, 11), (13, 23)]"""
- # Scan the string from left to right, finding commas and trimming whitespace.
- # This is essentially equivalent to (x.trim() fror x in str.split(','))
- # except that this yields spans within the string rather than the strings
- # themselves, and no span is yielded for a trailing comma.
- start, end = 0, len(enum_case_value)
- while start <= end:
- # Find a ',' delimiter to split on
- delimiter = enum_case_value.find(',', start, end)
- if delimiter < 0:
- delimiter = end
+ Example: 'SHOUTY_CASE, kCamelCase' -> [(0, 11), (13, 23)]"""
+ # Scan the string from left to right, finding commas and trimming whitespace.
+ # This is essentially equivalent to (x.strip() for x in str.split(','))
+ # except that this yields spans within the string rather than the strings
+ # themselves, and no span is yielded for a trailing comma.
+ start, end = 0, len(enum_case_value)
+ while start <= end:
+ # Find a ',' delimiter to split on
+ delimiter = enum_case_value.find(",", start, end)
+ if delimiter < 0:
+ delimiter = end
- substr_start = start
- substr_end = delimiter
+ substr_start = start
+ substr_end = delimiter
- # Drop leading whitespace
- while (substr_start < substr_end and
- enum_case_value[substr_start].isspace()):
- substr_start += 1
- # Drop trailing whitespace
- while (substr_start < substr_end and
- enum_case_value[substr_end - 1].isspace()):
- substr_end -= 1
+ # Drop leading whitespace
+ while substr_start < substr_end and enum_case_value[substr_start].isspace():
+ substr_start += 1
+ # Drop trailing whitespace
+ while substr_start < substr_end and enum_case_value[substr_end - 1].isspace():
+ substr_end -= 1
- # Skip a trailing comma
- if substr_start == end and start != 0:
- break
+ # Skip a trailing comma
+ if substr_start == end and start != 0:
+ break
- yield substr_start, substr_end
- start = delimiter + 1
+ yield substr_start, substr_end
+ start = delimiter + 1
def _split_enum_case_values(enum_case_value):
- """Returns all enum cases in an enum case value.
+ """Returns all enum cases in an enum case value.
- Example: 'SHOUTY_CASE, kCamelCase' -> ['SHOUTY_CASE', 'kCamelCase']"""
- return [enum_case_value[start:end] for start, end
- in _split_enum_case_values_into_spans(enum_case_value)]
+ Example: 'SHOUTY_CASE, kCamelCase' -> ['SHOUTY_CASE', 'kCamelCase']"""
+ return [
+ enum_case_value[start:end]
+ for start, end in _split_enum_case_values_into_spans(enum_case_value)
+ ]
def _get_enum_value_names(enum_value):
- """Determines one or more enum names based on attributes"""
- cases = ["SHOUTY_CASE"]
- name = enum_value.name.name.text
- if enum_case := ir_util.get_attribute(enum_value.attribute,
- attributes.Attribute.ENUM_CASE):
- cases = _split_enum_case_values(enum_case.string_constant.text)
- return [name_conversion.convert_case("SHOUTY_CASE", case, name)
- for case in cases]
+ """Determines one or more enum names based on attributes"""
+ cases = ["SHOUTY_CASE"]
+ name = enum_value.name.name.text
+ if enum_case := ir_util.get_attribute(
+ enum_value.attribute, attributes.Attribute.ENUM_CASE
+ ):
+ cases = _split_enum_case_values(enum_case.string_constant.text)
+ return [name_conversion.convert_case("SHOUTY_CASE", case, name) for case in cases]
def _generate_enum_definition(type_ir, include_traits=True):
- """Generates C++ for an Emboss enum."""
- enum_values = []
- enum_from_string_statements = []
- string_from_enum_statements = []
- enum_is_known_statements = []
- previously_seen_numeric_values = set()
- max_bits = ir_util.get_integer_attribute(type_ir.attribute, "maximum_bits")
- is_signed = ir_util.get_boolean_attribute(type_ir.attribute, "is_signed")
- enum_type = _cpp_integer_type_for_enum(max_bits, is_signed)
- for value in type_ir.enumeration.value:
- numeric_value = ir_util.constant_value(value.value)
- enum_value_names = _get_enum_value_names(value)
+ """Generates C++ for an Emboss enum."""
+ enum_values = []
+ enum_from_string_statements = []
+ string_from_enum_statements = []
+ enum_is_known_statements = []
+ previously_seen_numeric_values = set()
+ max_bits = ir_util.get_integer_attribute(type_ir.attribute, "maximum_bits")
+ is_signed = ir_util.get_boolean_attribute(type_ir.attribute, "is_signed")
+ enum_type = _cpp_integer_type_for_enum(max_bits, is_signed)
+ for value in type_ir.enumeration.value:
+ numeric_value = ir_util.constant_value(value.value)
+ enum_value_names = _get_enum_value_names(value)
- for enum_value_name in enum_value_names:
- enum_values.append(
- code_template.format_template(_TEMPLATES.enum_value,
- name=enum_value_name,
- value=_render_integer(numeric_value)))
- if include_traits:
- enum_from_string_statements.append(
- code_template.format_template(_TEMPLATES.enum_from_name_case,
- enum=type_ir.name.name.text,
- value=enum_value_name,
- name=value.name.name.text))
+ for enum_value_name in enum_value_names:
+ enum_values.append(
+ code_template.format_template(
+ _TEMPLATES.enum_value,
+ name=enum_value_name,
+ value=_render_integer(numeric_value),
+ )
+ )
+ if include_traits:
+ enum_from_string_statements.append(
+ code_template.format_template(
+ _TEMPLATES.enum_from_name_case,
+ enum=type_ir.name.name.text,
+ value=enum_value_name,
+ name=value.name.name.text,
+ )
+ )
- if numeric_value not in previously_seen_numeric_values:
- string_from_enum_statements.append(
- code_template.format_template(_TEMPLATES.name_from_enum_case,
- enum=type_ir.name.name.text,
- value=enum_value_name,
- name=value.name.name.text))
+ if numeric_value not in previously_seen_numeric_values:
+ string_from_enum_statements.append(
+ code_template.format_template(
+ _TEMPLATES.name_from_enum_case,
+ enum=type_ir.name.name.text,
+ value=enum_value_name,
+ name=value.name.name.text,
+ )
+ )
- enum_is_known_statements.append(
- code_template.format_template(_TEMPLATES.enum_is_known_case,
- enum=type_ir.name.name.text,
- name=enum_value_name))
- previously_seen_numeric_values.add(numeric_value)
+ enum_is_known_statements.append(
+ code_template.format_template(
+ _TEMPLATES.enum_is_known_case,
+ enum=type_ir.name.name.text,
+ name=enum_value_name,
+ )
+ )
+ previously_seen_numeric_values.add(numeric_value)
- declaration = code_template.format_template(
- _TEMPLATES.enum_declaration,
- enum=type_ir.name.name.text,
- enum_type=enum_type)
- definition = code_template.format_template(
- _TEMPLATES.enum_definition,
- enum=type_ir.name.name.text,
- enum_type=enum_type,
- enum_values="".join(enum_values))
- if include_traits:
- definition += code_template.format_template(
- _TEMPLATES.enum_traits,
- enum=type_ir.name.name.text,
- enum_from_name_cases="\n".join(enum_from_string_statements),
- name_from_enum_cases="\n".join(string_from_enum_statements),
- enum_is_known_cases="\n".join(enum_is_known_statements))
+ declaration = code_template.format_template(
+ _TEMPLATES.enum_declaration, enum=type_ir.name.name.text, enum_type=enum_type
+ )
+ definition = code_template.format_template(
+ _TEMPLATES.enum_definition,
+ enum=type_ir.name.name.text,
+ enum_type=enum_type,
+ enum_values="".join(enum_values),
+ )
+ if include_traits:
+ definition += code_template.format_template(
+ _TEMPLATES.enum_traits,
+ enum=type_ir.name.name.text,
+ enum_from_name_cases="\n".join(enum_from_string_statements),
+ name_from_enum_cases="\n".join(string_from_enum_statements),
+ enum_is_known_cases="\n".join(enum_is_known_statements),
+ )
- return (declaration, definition, "")
+ return (declaration, definition, "")
def _generate_type_definition(type_ir, ir, config: Config):
- """Generates C++ for an Emboss type."""
- if type_ir.HasField("structure"):
- return _generate_structure_definition(type_ir, ir, config)
- elif type_ir.HasField("enumeration"):
- return _generate_enum_definition(type_ir, config.include_enum_traits)
- elif type_ir.HasField("external"):
- # TODO(bolms): This should probably generate an #include.
- return "", "", ""
- else:
- # TODO(bolms): provide error message instead of ICE
- assert False, "Unknown type {}".format(type_ir)
+ """Generates C++ for an Emboss type."""
+ if type_ir.HasField("structure"):
+ return _generate_structure_definition(type_ir, ir, config)
+ elif type_ir.HasField("enumeration"):
+ return _generate_enum_definition(type_ir, config.include_enum_traits)
+ elif type_ir.HasField("external"):
+ # TODO(bolms): This should probably generate an #include.
+ return "", "", ""
+ else:
+ # TODO(bolms): provide error message instead of ICE
+ assert False, "Unknown type {}".format(type_ir)
def _generate_header_guard(file_path):
- # TODO(bolms): Make this configurable.
- header_path = file_path + ".h"
- uppercased_path = header_path.upper()
- no_punctuation_path = re.sub(r"[^A-Za-z0-9_]", "_", uppercased_path)
- suffixed_path = no_punctuation_path + "_"
- no_double_underscore_path = re.sub(r"__+", "_", suffixed_path)
- return no_double_underscore_path
+ # TODO(bolms): Make this configurable.
+ header_path = file_path + ".h"
+ uppercased_path = header_path.upper()
+ no_punctuation_path = re.sub(r"[^A-Za-z0-9_]", "_", uppercased_path)
+ suffixed_path = no_punctuation_path + "_"
+ no_double_underscore_path = re.sub(r"__+", "_", suffixed_path)
+ return no_double_underscore_path
def _add_missing_enum_case_attribute_on_enum_value(enum_value, defaults):
- """Adds an `enum_case` attribute if there isn't one but a default is set."""
- if ir_util.get_attribute(enum_value.attribute,
- attributes.Attribute.ENUM_CASE) is None:
- if attributes.Attribute.ENUM_CASE in defaults:
- enum_value.attribute.extend([defaults[attributes.Attribute.ENUM_CASE]])
+ """Adds an `enum_case` attribute if there isn't one but a default is set."""
+ if (
+ ir_util.get_attribute(enum_value.attribute, attributes.Attribute.ENUM_CASE)
+ is None
+ ):
+ if attributes.Attribute.ENUM_CASE in defaults:
+ enum_value.attribute.extend([defaults[attributes.Attribute.ENUM_CASE]])
def _propagate_defaults(ir, targets, ancestors, add_fn):
- """Propagates default values
+ """Propagates default values
- Traverses the IR to propagate default values to target nodes.
+ Traverses the IR to propagate default values to target nodes.
- Arguments:
- targets: A list of target IR types to add attributes to.
- ancestors: Ancestor types which may contain the default values.
- add_fn: Function to add the attribute. May use any parameter available in
- fast_traverse_ir_top_down actions as well as `defaults` containing the
- default attributes set by ancestors.
+ Arguments:
+ targets: A list of target IR types to add attributes to.
+ ancestors: Ancestor types which may contain the default values.
+ add_fn: Function to add the attribute. May use any parameter available in
+ fast_traverse_ir_top_down actions as well as `defaults` containing the
+ default attributes set by ancestors.
- Returns:
- None
- """
- traverse_ir.fast_traverse_ir_top_down(
- ir, targets, add_fn,
- incidental_actions={
- ancestor: attribute_util.gather_default_attributes
- for ancestor in ancestors
- },
- parameters={"defaults": {}})
+ Returns:
+ None
+ """
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ targets,
+ add_fn,
+ incidental_actions={
+ ancestor: attribute_util.gather_default_attributes for ancestor in ancestors
+ },
+ parameters={"defaults": {}},
+ )
def _offset_source_location_column(source_location, offset):
- """Adds offsets from the start column of the supplied source location
+ """Adds offsets from the start column of the supplied source location
- Returns a new source location with all of the same properties as the provided
- source location, but with the columns modified by offsets from the original
- start column.
+ Returns a new source location with all of the same properties as the provided
+ source location, but with the columns modified by offsets from the original
+ start column.
- Offset should be a tuple of (start, end), which are the offsets relative to
- source_location.start.column to set the new start.column and end.column."""
+ Offset should be a tuple of (start, end), which are the offsets relative to
+ source_location.start.column to set the new start.column and end.column."""
- new_location = ir_data_utils.copy(source_location)
- new_location.start.column = source_location.start.column + offset[0]
- new_location.end.column = source_location.start.column + offset[1]
+ new_location = ir_data_utils.copy(source_location)
+ new_location.start.column = source_location.start.column + offset[0]
+ new_location.end.column = source_location.start.column + offset[1]
- return new_location
+ return new_location
def _verify_namespace_attribute(attr, source_file_name, errors):
- if attr.name.text != attributes.Attribute.NAMESPACE:
- return
- namespace_value = ir_data_utils.reader(attr).value.string_constant
- if not re.match(_NS_RE, namespace_value.text):
- if re.match(_NS_EMPTY_RE, namespace_value.text):
- errors.append([error.error(
- source_file_name, namespace_value.source_location,
- 'Empty namespace value is not allowed.')])
- elif re.match(_NS_GLOBAL_RE, namespace_value.text):
- errors.append([error.error(
- source_file_name, namespace_value.source_location,
- 'Global namespace is not allowed.')])
- else:
- errors.append([error.error(
- source_file_name, namespace_value.source_location,
- 'Invalid namespace, must be a valid C++ namespace, such as "abc", '
- '"abc::def", or "::abc::def::ghi" (ISO/IEC 14882:2017 '
- 'enclosing-namespace-specifier).')])
- return
- for word in _get_namespace_components(namespace_value.text):
- if word in _CPP_RESERVED_WORDS:
- errors.append([error.error(
- source_file_name, namespace_value.source_location,
- f'Reserved word "{word}" is not allowed as a namespace component.'
- )])
+ if attr.name.text != attributes.Attribute.NAMESPACE:
+ return
+ namespace_value = ir_data_utils.reader(attr).value.string_constant
+ if not re.match(_NS_RE, namespace_value.text):
+ if re.match(_NS_EMPTY_RE, namespace_value.text):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ namespace_value.source_location,
+ "Empty namespace value is not allowed.",
+ )
+ ]
+ )
+ elif re.match(_NS_GLOBAL_RE, namespace_value.text):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ namespace_value.source_location,
+ "Global namespace is not allowed.",
+ )
+ ]
+ )
+ else:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ namespace_value.source_location,
+ 'Invalid namespace, must be a valid C++ namespace, such as "abc", '
+ '"abc::def", or "::abc::def::ghi" (ISO/IEC 14882:2017 '
+ "enclosing-namespace-specifier).",
+ )
+ ]
+ )
+ return
+ for word in _get_namespace_components(namespace_value.text):
+ if word in _CPP_RESERVED_WORDS:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ namespace_value.source_location,
+ f'Reserved word "{word}" is not allowed as a namespace component.',
+ )
+ ]
+ )
def _verify_enum_case_attribute(attr, source_file_name, errors):
- """Verify that `enum_case` values are supported."""
- if attr.name.text != attributes.Attribute.ENUM_CASE:
- return
+ """Verify that `enum_case` values are supported."""
+ if attr.name.text != attributes.Attribute.ENUM_CASE:
+ return
- VALID_CASES = ', '.join(case for case in _SUPPORTED_ENUM_CASES)
- enum_case_value = attr.value.string_constant
- case_spans = _split_enum_case_values_into_spans(enum_case_value.text)
- seen_cases = set()
+ VALID_CASES = ", ".join(case for case in _SUPPORTED_ENUM_CASES)
+ enum_case_value = attr.value.string_constant
+ case_spans = _split_enum_case_values_into_spans(enum_case_value.text)
+ seen_cases = set()
- for start, end in case_spans:
- case_source_location = _offset_source_location_column(
- enum_case_value.source_location, (start, end))
- case = enum_case_value.text[start:end]
+ for start, end in case_spans:
+ case_source_location = _offset_source_location_column(
+ enum_case_value.source_location, (start, end)
+ )
+ case = enum_case_value.text[start:end]
- if start == end:
- errors.append([error.error(
- source_file_name, case_source_location,
- 'Empty enum case (or excess comma).')])
- continue
+ if start == end:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ )
+ continue
- if case in seen_cases:
- errors.append([error.error(
- source_file_name, case_source_location,
- f'Duplicate enum case "{case}".')])
- continue
- seen_cases.add(case)
+ if case in seen_cases:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ case_source_location,
+ f'Duplicate enum case "{case}".',
+ )
+ ]
+ )
+ continue
+ seen_cases.add(case)
- if case not in _SUPPORTED_ENUM_CASES:
- errors.append([error.error(
- source_file_name, case_source_location,
- f'Unsupported enum case "{case}", '
- f'supported cases are: {VALID_CASES}.')])
+ if case not in _SUPPORTED_ENUM_CASES:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ case_source_location,
+ f'Unsupported enum case "{case}", '
+ f"supported cases are: {VALID_CASES}.",
+ )
+ ]
+ )
def _verify_attribute_values(ir):
- """Verify backend attribute values."""
- errors = []
+ """Verify backend attribute values."""
+ errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Attribute], _verify_namespace_attribute,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Attribute], _verify_enum_case_attribute,
- parameters={"errors": errors})
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Attribute],
+ _verify_namespace_attribute,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Attribute],
+ _verify_enum_case_attribute,
+ parameters={"errors": errors},
+ )
- return errors
+ return errors
def _propagate_defaults_and_verify_attributes(ir):
- """Verify attributes and ensure defaults are set when not overridden.
+ """Verify attributes and ensure defaults are set when not overridden.
- Returns a list of errors if there are errors present, or an empty list if
- verification completed successfully."""
- if errors := attribute_util.check_attributes_in_ir(
- ir,
- back_end="cpp",
- types=attributes.TYPES,
- module_attributes=attributes.Scope.MODULE,
- struct_attributes=attributes.Scope.STRUCT,
- bits_attributes=attributes.Scope.BITS,
- enum_attributes=attributes.Scope.ENUM,
- enum_value_attributes=attributes.Scope.ENUM_VALUE):
- return errors
+ Returns a list of errors if there are errors present, or an empty list if
+ verification completed successfully."""
+ if errors := attribute_util.check_attributes_in_ir(
+ ir,
+ back_end="cpp",
+ types=attributes.TYPES,
+ module_attributes=attributes.Scope.MODULE,
+ struct_attributes=attributes.Scope.STRUCT,
+ bits_attributes=attributes.Scope.BITS,
+ enum_attributes=attributes.Scope.ENUM,
+ enum_value_attributes=attributes.Scope.ENUM_VALUE,
+ ):
+ return errors
- if errors := _verify_attribute_values(ir):
- return errors
+ if errors := _verify_attribute_values(ir):
+ return errors
- # Ensure defaults are set on EnumValues for `enum_case`.
- _propagate_defaults(
- ir,
- targets=[ir_data.EnumValue],
- ancestors=[ir_data.Module, ir_data.TypeDefinition],
- add_fn=_add_missing_enum_case_attribute_on_enum_value)
+ # Ensure defaults are set on EnumValues for `enum_case`.
+ _propagate_defaults(
+ ir,
+ targets=[ir_data.EnumValue],
+ ancestors=[ir_data.Module, ir_data.TypeDefinition],
+ add_fn=_add_missing_enum_case_attribute_on_enum_value,
+ )
- return []
+ return []
def generate_header(ir, config=Config()):
- """Generates a C++ header from an Emboss module.
+ """Generates a C++ header from an Emboss module.
- Arguments:
- ir: An EmbossIr of the module.
+ Arguments:
+ ir: An EmbossIr of the module.
- Returns:
- A tuple of (header, errors), where `header` is either a string containing
- the text of a C++ header which implements Views for the types in the Emboss
- module, or None, and `errors` is a possibly-empty list of error messages to
- display to the user.
- """
- errors = _propagate_defaults_and_verify_attributes(ir)
- if errors:
- return None, errors
- type_declarations = []
- type_definitions = []
- method_definitions = []
- for type_definition in ir.module[0].type:
- declaration, definition, methods = _generate_type_definition(
- type_definition, ir, config)
- type_declarations.append(declaration)
- type_definitions.append(definition)
- method_definitions.append(methods)
- body = code_template.format_template(
- _TEMPLATES.body,
- type_declarations="".join(type_declarations),
- type_definitions="".join(type_definitions),
- method_definitions="".join(method_definitions))
- body = _wrap_in_namespace(body, _get_module_namespace(ir.module[0]))
- includes = _get_includes(ir.module[0], config)
- return code_template.format_template(
- _TEMPLATES.outline,
- includes=includes,
- body=body,
- header_guard=_generate_header_guard(ir.module[0].source_file_name)), []
+ Returns:
+ A tuple of (header, errors), where `header` is either a string containing
+ the text of a C++ header which implements Views for the types in the Emboss
+ module, or None, and `errors` is a possibly-empty list of error messages to
+ display to the user.
+ """
+ errors = _propagate_defaults_and_verify_attributes(ir)
+ if errors:
+ return None, errors
+ type_declarations = []
+ type_definitions = []
+ method_definitions = []
+ for type_definition in ir.module[0].type:
+ declaration, definition, methods = _generate_type_definition(
+ type_definition, ir, config
+ )
+ type_declarations.append(declaration)
+ type_definitions.append(definition)
+ method_definitions.append(methods)
+ body = code_template.format_template(
+ _TEMPLATES.body,
+ type_declarations="".join(type_declarations),
+ type_definitions="".join(type_definitions),
+ method_definitions="".join(method_definitions),
+ )
+ body = _wrap_in_namespace(body, _get_module_namespace(ir.module[0]))
+ includes = _get_includes(ir.module[0], config)
+ return (
+ code_template.format_template(
+ _TEMPLATES.outline,
+ includes=includes,
+ body=body,
+ header_guard=_generate_header_guard(ir.module[0].source_file_name),
+ ),
+ [],
+ )
diff --git a/compiler/back_end/cpp/header_generator_test.py b/compiler/back_end/cpp/header_generator_test.py
index d58f798..6d31df8 100644
--- a/compiler/back_end/cpp/header_generator_test.py
+++ b/compiler/back_end/cpp/header_generator_test.py
@@ -22,359 +22,529 @@
from compiler.util import ir_data_utils
from compiler.util import test_util
+
def _make_ir_from_emb(emb_text, name="m.emb"):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- name,
- test_util.dict_file_reader({name: emb_text}))
- assert not errors
- return ir
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ name, test_util.dict_file_reader({name: emb_text})
+ )
+ assert not errors
+ return ir
class NormalizeIrTest(unittest.TestCase):
- def test_accepts_string_attribute(self):
- ir = _make_ir_from_emb('[(cpp) namespace: "foo"]\n')
- self.assertEqual([], header_generator.generate_header(ir)[1])
+ def test_accepts_string_attribute(self):
+ ir = _make_ir_from_emb('[(cpp) namespace: "foo"]\n')
+ self.assertEqual([], header_generator.generate_header(ir)[1])
- def test_rejects_wrong_type_for_string_attribute(self):
- ir = _make_ir_from_emb("[(cpp) namespace: 9]\n")
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.value.source_location,
- "Attribute '(cpp) namespace' must have a string value.")
- ]], header_generator.generate_header(ir)[1])
+ def test_rejects_wrong_type_for_string_attribute(self):
+ ir = _make_ir_from_emb("[(cpp) namespace: 9]\n")
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Attribute '(cpp) namespace' must have a string value.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_emboss_internal_attribute_with_back_end_specifier(self):
- ir = _make_ir_from_emb('[(cpp) byte_order: "LittleEndian"]\n')
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.name.source_location,
- "Unknown attribute '(cpp) byte_order' on module 'm.emb'.")
- ]], header_generator.generate_header(ir)[1])
+ def test_rejects_emboss_internal_attribute_with_back_end_specifier(self):
+ ir = _make_ir_from_emb('[(cpp) byte_order: "LittleEndian"]\n')
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.name.source_location,
+ "Unknown attribute '(cpp) byte_order' on module 'm.emb'.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_accepts_enum_case(self):
- mod_ir = _make_ir_from_emb('[(cpp) $default enum_case: "kCamelCase"]')
- self.assertEqual([], header_generator.generate_header(mod_ir)[1])
- enum_ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- self.assertEqual([], header_generator.generate_header(enum_ir)[1])
- enum_value_ir = _make_ir_from_emb('enum Foo:\n'
- ' BAR = 1 [(cpp) enum_case: "kCamelCase"]\n'
- ' BAZ = 2\n'
- ' [(cpp) enum_case: "kCamelCase"]\n')
- self.assertEqual([], header_generator.generate_header(enum_value_ir)[1])
- enum_in_struct_ir = _make_ir_from_emb('struct Outer:\n'
- ' [(cpp) $default enum_case: "kCamelCase"]\n'
- ' enum Inner:\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- self.assertEqual([], header_generator.generate_header(enum_in_struct_ir)[1])
- enum_in_bits_ir = _make_ir_from_emb('bits Outer:\n'
- ' [(cpp) $default enum_case: "kCamelCase"]\n'
- ' enum Inner:\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- self.assertEqual([], header_generator.generate_header(enum_in_bits_ir)[1])
- enum_ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE,"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- self.assertEqual([], header_generator.generate_header(enum_ir)[1])
- enum_ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE ,kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- self.assertEqual([], header_generator.generate_header(enum_ir)[1])
+ def test_accepts_enum_case(self):
+ mod_ir = _make_ir_from_emb('[(cpp) $default enum_case: "kCamelCase"]')
+ self.assertEqual([], header_generator.generate_header(mod_ir)[1])
+ enum_ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ self.assertEqual([], header_generator.generate_header(enum_ir)[1])
+ enum_value_ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' BAR = 1 [(cpp) enum_case: "kCamelCase"]\n'
+ " BAZ = 2\n"
+ ' [(cpp) enum_case: "kCamelCase"]\n'
+ )
+ self.assertEqual([], header_generator.generate_header(enum_value_ir)[1])
+ enum_in_struct_ir = _make_ir_from_emb(
+ "struct Outer:\n"
+ ' [(cpp) $default enum_case: "kCamelCase"]\n'
+ " enum Inner:\n"
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ self.assertEqual([], header_generator.generate_header(enum_in_struct_ir)[1])
+ enum_in_bits_ir = _make_ir_from_emb(
+ "bits Outer:\n"
+ ' [(cpp) $default enum_case: "kCamelCase"]\n'
+ " enum Inner:\n"
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ self.assertEqual([], header_generator.generate_header(enum_in_bits_ir)[1])
+ enum_ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE,"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ self.assertEqual([], header_generator.generate_header(enum_ir)[1])
+ enum_ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE ,kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ self.assertEqual([], header_generator.generate_header(enum_ir)[1])
- def test_rejects_bad_enum_case_at_start(self):
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHORTY_CASE, kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- attr = ir.module[0].type[0].attribute[0]
+ def test_rejects_bad_enum_case_at_start(self):
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHORTY_CASE, kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ attr = ir.module[0].type[0].attribute[0]
- bad_case_source_location = ir_data.Location()
- bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
- bad_case_source_location.CopyFrom(attr.value.source_location)
- # Location of SHORTY_CASE in the attribute line.
- bad_case_source_location.start.column = 30
- bad_case_source_location.end.column = 41
+ bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
+ bad_case_source_location.CopyFrom(attr.value.source_location)
+ # Location of SHORTY_CASE in the attribute line.
+ bad_case_source_location.start.column = 30
+ bad_case_source_location.end.column = 41
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Unsupported enum case "SHORTY_CASE", '
- 'supported cases are: SHOUTY_CASE, kCamelCase.')
- ]], header_generator.generate_header(ir)[1])
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ 'Unsupported enum case "SHORTY_CASE", '
+ "supported cases are: SHOUTY_CASE, kCamelCase.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_bad_enum_case_in_middle(self):
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE, bad_CASE, kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- attr = ir.module[0].type[0].attribute[0]
+ def test_rejects_bad_enum_case_in_middle(self):
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE, bad_CASE, kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ attr = ir.module[0].type[0].attribute[0]
- bad_case_source_location = ir_data.Location()
- bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
- bad_case_source_location.CopyFrom(attr.value.source_location)
- # Location of bad_CASE in the attribute line.
- bad_case_source_location.start.column = 43
- bad_case_source_location.end.column = 51
+ bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
+ bad_case_source_location.CopyFrom(attr.value.source_location)
+ # Location of bad_CASE in the attribute line.
+ bad_case_source_location.start.column = 43
+ bad_case_source_location.end.column = 51
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Unsupported enum case "bad_CASE", '
- 'supported cases are: SHOUTY_CASE, kCamelCase.')
- ]], header_generator.generate_header(ir)[1])
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ 'Unsupported enum case "bad_CASE", '
+ "supported cases are: SHOUTY_CASE, kCamelCase.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_bad_enum_case_at_end(self):
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase, BAD_case"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- attr = ir.module[0].type[0].attribute[0]
+ def test_rejects_bad_enum_case_at_end(self):
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase, BAD_case"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ attr = ir.module[0].type[0].attribute[0]
- bad_case_source_location = ir_data.Location()
- bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
- bad_case_source_location.CopyFrom(attr.value.source_location)
- # Location of BAD_case in the attribute line.
- bad_case_source_location.start.column = 55
- bad_case_source_location.end.column = 63
+ bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
+ bad_case_source_location.CopyFrom(attr.value.source_location)
+ # Location of BAD_case in the attribute line.
+ bad_case_source_location.start.column = 55
+ bad_case_source_location.end.column = 63
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Unsupported enum case "BAD_case", '
- 'supported cases are: SHOUTY_CASE, kCamelCase.')
- ]], header_generator.generate_header(ir)[1])
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ 'Unsupported enum case "BAD_case", '
+ "supported cases are: SHOUTY_CASE, kCamelCase.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_duplicate_enum_case(self):
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE, SHOUTY_CASE"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- attr = ir.module[0].type[0].attribute[0]
+ def test_rejects_duplicate_enum_case(self):
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE, SHOUTY_CASE"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ attr = ir.module[0].type[0].attribute[0]
- bad_case_source_location = ir_data.Location()
- bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
- bad_case_source_location.CopyFrom(attr.value.source_location)
- # Location of the second SHOUTY_CASE in the attribute line.
- bad_case_source_location.start.column = 43
- bad_case_source_location.end.column = 54
+ bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
+ bad_case_source_location.CopyFrom(attr.value.source_location)
+ # Location of the second SHOUTY_CASE in the attribute line.
+ bad_case_source_location.start.column = 43
+ bad_case_source_location.end.column = 54
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Duplicate enum case "SHOUTY_CASE".')
- ]], header_generator.generate_header(ir)[1])
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ 'Duplicate enum case "SHOUTY_CASE".',
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
+ def test_rejects_empty_enum_case(self):
+ # Double comma
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE,, kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
+ attr = ir.module[0].type[0].attribute[0]
- def test_rejects_empty_enum_case(self):
- # Double comma
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE,, kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
- attr = ir.module[0].type[0].attribute[0]
+ bad_case_source_location = ir_data.Location()
+ bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
+ bad_case_source_location.CopyFrom(attr.value.source_location)
+ # Location of excess comma.
+ bad_case_source_location.start.column = 42
+ bad_case_source_location.end.column = 42
- bad_case_source_location = ir_data.Location()
- bad_case_source_location = ir_data_utils.builder(bad_case_source_location)
- bad_case_source_location.CopyFrom(attr.value.source_location)
- # Location of excess comma.
- bad_case_source_location.start.column = 42
- bad_case_source_location.end.column = 42
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ # Leading comma
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: ", SHOUTY_CASE, kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
- # Leading comma
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: ", SHOUTY_CASE, kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
+ bad_case_source_location.start.column = 30
+ bad_case_source_location.end.column = 30
- bad_case_source_location.start.column = 30
- bad_case_source_location.end.column = 30
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ # Excess trailing comma
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase,,"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
- # Excess trailing comma
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE, kCamelCase,,"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
+ bad_case_source_location.start.column = 54
+ bad_case_source_location.end.column = 54
- bad_case_source_location.start.column = 54
- bad_case_source_location.end.column = 54
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ # Whitespace enum case
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: "SHOUTY_CASE, , kCamelCase"]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
- # Whitespace enum case
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: "SHOUTY_CASE, , kCamelCase"]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
+ bad_case_source_location.start.column = 45
+ bad_case_source_location.end.column = 45
- bad_case_source_location.start.column = 45
- bad_case_source_location.end.column = 45
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ # Empty enum_case string
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: ""]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
- # Empty enum_case string
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: ""]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
+ bad_case_source_location.start.column = 30
+ bad_case_source_location.end.column = 30
- bad_case_source_location.start.column = 30
- bad_case_source_location.end.column = 30
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ # Whitespace enum_case string
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: " "]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
- # Whitespace enum_case string
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: " "]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
+ bad_case_source_location.start.column = 35
+ bad_case_source_location.end.column = 35
- bad_case_source_location.start.column = 35
- bad_case_source_location.end.column = 35
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ # One-character whitespace enum_case string
+ ir = _make_ir_from_emb(
+ "enum Foo:\n"
+ ' [(cpp) $default enum_case: " "]\n'
+ " BAR = 1\n"
+ " BAZ = 2\n"
+ )
- # One-character whitespace enum_case string
- ir = _make_ir_from_emb('enum Foo:\n'
- ' [(cpp) $default enum_case: " "]\n'
- ' BAR = 1\n'
- ' BAZ = 2\n')
+ bad_case_source_location.start.column = 31
+ bad_case_source_location.end.column = 31
- bad_case_source_location.start.column = 31
- bad_case_source_location.end.column = 31
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ bad_case_source_location,
+ "Empty enum case (or excess comma).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- self.assertEqual([[
- error.error("m.emb", bad_case_source_location,
- 'Empty enum case (or excess comma).')
- ]], header_generator.generate_header(ir)[1])
+ def test_accepts_namespace(self):
+ for test in [
+ '[(cpp) namespace: "basic"]\n',
+ '[(cpp) namespace: "multiple::components"]\n',
+ '[(cpp) namespace: "::absolute"]\n',
+ '[(cpp) namespace: "::fully::qualified"]\n',
+ '[(cpp) namespace: "CAN::Be::cAPITAL"]\n',
+ '[(cpp) namespace: "trailingNumbers54321"]\n',
+ '[(cpp) namespace: "containing4321numbers"]\n',
+ '[(cpp) namespace: "can_have_underscores"]\n',
+ '[(cpp) namespace: "_initial_underscore"]\n',
+ '[(cpp) namespace: "_initial::_underscore"]\n',
+ '[(cpp) namespace: "::_initial::_underscore"]\n',
+ '[(cpp) namespace: "trailing_underscore_"]\n',
+ '[(cpp) namespace: "trailing_::underscore_"]\n',
+ '[(cpp) namespace: "::trailing_::underscore_"]\n',
+ '[(cpp) namespace: " spaces "]\n',
+ '[(cpp) namespace: "with :: spaces"]\n',
+ '[(cpp) namespace: " ::fully:: qualified :: with::spaces"]\n',
+ ]:
+ ir = _make_ir_from_emb(test)
+ self.assertEqual([], header_generator.generate_header(ir)[1])
- def test_accepts_namespace(self):
- for test in [
- '[(cpp) namespace: "basic"]\n',
- '[(cpp) namespace: "multiple::components"]\n',
- '[(cpp) namespace: "::absolute"]\n',
- '[(cpp) namespace: "::fully::qualified"]\n',
- '[(cpp) namespace: "CAN::Be::cAPITAL"]\n',
- '[(cpp) namespace: "trailingNumbers54321"]\n',
- '[(cpp) namespace: "containing4321numbers"]\n',
- '[(cpp) namespace: "can_have_underscores"]\n',
- '[(cpp) namespace: "_initial_underscore"]\n',
- '[(cpp) namespace: "_initial::_underscore"]\n',
- '[(cpp) namespace: "::_initial::_underscore"]\n',
- '[(cpp) namespace: "trailing_underscore_"]\n',
- '[(cpp) namespace: "trailing_::underscore_"]\n',
- '[(cpp) namespace: "::trailing_::underscore_"]\n',
- '[(cpp) namespace: " spaces "]\n',
- '[(cpp) namespace: "with :: spaces"]\n',
- '[(cpp) namespace: " ::fully:: qualified :: with::spaces"]\n',
- ]:
- ir = _make_ir_from_emb(test)
- self.assertEqual([], header_generator.generate_header(ir)[1])
+ def test_rejects_non_namespace_strings(self):
+ for test in [
+ '[(cpp) namespace: "5th::avenue"]\n',
+ '[(cpp) namespace: "can\'t::have::apostrophe"]\n',
+ '[(cpp) namespace: "cannot-have-dash"]\n',
+ '[(cpp) namespace: "no/slashes"]\n',
+ '[(cpp) namespace: "no\\\\slashes"]\n',
+ '[(cpp) namespace: "apostrophes*are*rejected"]\n',
+ '[(cpp) namespace: "avoid.dot"]\n',
+ '[(cpp) namespace: "does5+5"]\n',
+ '[(cpp) namespace: "=10"]\n',
+ '[(cpp) namespace: "?"]\n',
+ '[(cpp) namespace: "reject::spaces in::components"]\n',
+ '[(cpp) namespace: "totally::valid::but::extra +"]\n',
+ '[(cpp) namespace: "totally::valid::but::extra ::?"]\n',
+ '[(cpp) namespace: "< totally::valid::but::extra"]\n',
+ '[(cpp) namespace: "< ::totally::valid::but::extra"]\n',
+ '[(cpp) namespace: "::totally::valid::but::extra::"]\n',
+ '[(cpp) namespace: ":::extra::colon"]\n',
+ '[(cpp) namespace: "::extra:::colon"]\n',
+ ]:
+ ir = _make_ir_from_emb(test)
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Invalid namespace, must be a valid C++ namespace, such "
+ 'as "abc", "abc::def", or "::abc::def::ghi" (ISO/IEC '
+ "14882:2017 enclosing-namespace-specifier).",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_non_namespace_strings(self):
- for test in [
- '[(cpp) namespace: "5th::avenue"]\n',
- '[(cpp) namespace: "can\'t::have::apostrophe"]\n',
- '[(cpp) namespace: "cannot-have-dash"]\n',
- '[(cpp) namespace: "no/slashes"]\n',
- '[(cpp) namespace: "no\\\\slashes"]\n',
- '[(cpp) namespace: "apostrophes*are*rejected"]\n',
- '[(cpp) namespace: "avoid.dot"]\n',
- '[(cpp) namespace: "does5+5"]\n',
- '[(cpp) namespace: "=10"]\n',
- '[(cpp) namespace: "?"]\n',
- '[(cpp) namespace: "reject::spaces in::components"]\n',
- '[(cpp) namespace: "totally::valid::but::extra +"]\n',
- '[(cpp) namespace: "totally::valid::but::extra ::?"]\n',
- '[(cpp) namespace: "< totally::valid::but::extra"]\n',
- '[(cpp) namespace: "< ::totally::valid::but::extra"]\n',
- '[(cpp) namespace: "::totally::valid::but::extra::"]\n',
- '[(cpp) namespace: ":::extra::colon"]\n',
- '[(cpp) namespace: "::extra:::colon"]\n',
- ]:
- ir = _make_ir_from_emb(test)
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.value.source_location,
- 'Invalid namespace, must be a valid C++ namespace, such '
- 'as "abc", "abc::def", or "::abc::def::ghi" (ISO/IEC '
- '14882:2017 enclosing-namespace-specifier).')
- ]], header_generator.generate_header(ir)[1])
+ def test_rejects_empty_namespace(self):
+ for test in [
+ '[(cpp) namespace: ""]\n',
+ '[(cpp) namespace: " "]\n',
+ '[(cpp) namespace: " "]\n',
+ ]:
+ ir = _make_ir_from_emb(test)
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Empty namespace value is not allowed.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_empty_namespace(self):
- for test in [
- '[(cpp) namespace: ""]\n',
- '[(cpp) namespace: " "]\n',
- '[(cpp) namespace: " "]\n',
- ]:
- ir = _make_ir_from_emb(test)
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.value.source_location,
- 'Empty namespace value is not allowed.')
- ]], header_generator.generate_header(ir)[1])
+ def test_rejects_global_namespace(self):
+ for test in [
+ '[(cpp) namespace: "::"]\n',
+ '[(cpp) namespace: " ::"]\n',
+ '[(cpp) namespace: ":: "]\n',
+ '[(cpp) namespace: " :: "]\n',
+ ]:
+ ir = _make_ir_from_emb(test)
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Global namespace is not allowed.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
- def test_rejects_global_namespace(self):
- for test in [
- '[(cpp) namespace: "::"]\n',
- '[(cpp) namespace: " ::"]\n',
- '[(cpp) namespace: ":: "]\n',
- '[(cpp) namespace: " :: "]\n',
- ]:
- ir = _make_ir_from_emb(test)
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.value.source_location,
- 'Global namespace is not allowed.')
- ]], header_generator.generate_header(ir)[1])
+ def test_rejects_reserved_namespace(self):
+ for test, expected in [
+ # Only component
+ ('[(cpp) namespace: "class"]\n', "class"),
+ # Only component, fully qualified name
+ ('[(cpp) namespace: "::const"]\n', "const"),
+ # First component
+ ('[(cpp) namespace: "if::valid"]\n', "if"),
+ # First component, fully qualified name
+ ('[(cpp) namespace: "::auto::pilot"]\n', "auto"),
+ # Last component
+ ('[(cpp) namespace: "make::do"]\n', "do"),
+ # Middle component
+ ('[(cpp) namespace: "our::new::product"]\n', "new"),
+ ]:
+ ir = _make_ir_from_emb(test)
+ attr = ir.module[0].attribute[0]
- def test_rejects_reserved_namespace(self):
- for test, expected in [
- # Only component
- ('[(cpp) namespace: "class"]\n', 'class'),
- # Only component, fully qualified name
- ('[(cpp) namespace: "::const"]\n', 'const'),
- # First component
- ('[(cpp) namespace: "if::valid"]\n', 'if'),
- # First component, fully qualified name
- ('[(cpp) namespace: "::auto::pilot"]\n', 'auto'),
- # Last component
- ('[(cpp) namespace: "make::do"]\n', 'do'),
- # Middle component
- ('[(cpp) namespace: "our::new::product"]\n', 'new'),
- ]:
- ir = _make_ir_from_emb(test)
- attr = ir.module[0].attribute[0]
-
- self.assertEqual([[
- error.error("m.emb", attr.value.source_location,
- f'Reserved word "{expected}" is not allowed '
- f'as a namespace component.')]],
- header_generator.generate_header(ir)[1])
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ f'Reserved word "{expected}" is not allowed '
+ f"as a namespace component.",
+ )
+ ]
+ ],
+ header_generator.generate_header(ir)[1],
+ )
if __name__ == "__main__":
diff --git a/compiler/back_end/util/__init__.py b/compiler/back_end/util/__init__.py
index 2c31d84..086a24e 100644
--- a/compiler/back_end/util/__init__.py
+++ b/compiler/back_end/util/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/compiler/back_end/util/code_template.py b/compiler/back_end/util/code_template.py
index 519fb99..e69c8c3 100644
--- a/compiler/back_end/util/code_template.py
+++ b/compiler/back_end/util/code_template.py
@@ -23,83 +23,83 @@
def format_template(template, **kwargs):
- """format_template acts like str.format, but uses ${name} instead of {name}.
+ """format_template acts like str.format, but uses ${name} instead of {name}.
- format_template acts like a str.format, except that instead of using { and }
- to delimit substitutions, format_template uses ${name}. This simplifies
- templates of source code in most languages, which frequently use "{" and "}",
- but very rarely use "$".
+ format_template acts like a str.format, except that instead of using { and }
+ to delimit substitutions, format_template uses ${name}. This simplifies
+ templates of source code in most languages, which frequently use "{" and "}",
+ but very rarely use "$".
- See the documentation for string.Template for details about
- template strings and the format of substitutions.
+ See the documentation for string.Template for details about
+ template strings and the format of substitutions.
- Arguments:
- template: A template to format.
- **kwargs: Keyword arguments for string.Template.substitute.
+ Arguments:
+ template: A template to format.
+ **kwargs: Keyword arguments for string.Template.substitute.
- Returns:
- A formatted string.
- """
- return template.substitute(**kwargs)
+ Returns:
+ A formatted string.
+ """
+ return template.substitute(**kwargs)
def parse_templates(text):
- """Parses text into a namedtuple of templates.
+ """Parses text into a namedtuple of templates.
- parse_templates will split its argument into templates by searching for lines
- of the form:
+ parse_templates will split its argument into templates by searching for lines
+ of the form:
- [punctuation] " ** " [name] " ** " [punctuation]
+ [punctuation] " ** " [name] " ** " [punctuation]
- e.g.:
+ e.g.:
- // ** struct_field_accessor ** ////////
+ // ** struct_field_accessor ** ////////
- Leading and trailing punctuation is ignored, and [name] is used as the name
- of the template. [name] should match [A-Za-z][A-Za-z0-9_]* -- that is, it
- should be a valid ASCII Python identifier.
+ Leading and trailing punctuation is ignored, and [name] is used as the name
+ of the template. [name] should match [A-Za-z][A-Za-z0-9_]* -- that is, it
+ should be a valid ASCII Python identifier.
- Additionally any `//` style comment without leading space of the form:
- ```C++
- // This is an emboss developer related comment, it's useful internally
- // but not relevant to end-users of generated code.
- ```
- will be stripped out of the generated code.
+ Additionally any `//` style comment without leading space of the form:
+ ```C++
+ // This is an emboss developer related comment, it's useful internally
+ // but not relevant to end-users of generated code.
+ ```
+ will be stripped out of the generated code.
- If a template wants to define a comment that will be included in the
- generated code a C-style comment is recommended:
- ```C++
- /** This will be included in the generated source. */
+ If a template wants to define a comment that will be included in the
+ generated code a C-style comment is recommended:
+ ```C++
+ /** This will be included in the generated source. */
- /**
- * So will this!
- */
- ```
+ /**
+ * So will this!
+ */
+ ```
- Arguments:
- text: The text to parse into templates.
+ Arguments:
+ text: The text to parse into templates.
- Returns:
- A namedtuple object whose attributes are the templates from text.
- """
- delimiter_re = re.compile(r"^\W*\*\* ([A-Za-z][A-Za-z0-9_]*) \*\*\W*$")
- comment_re = re.compile(r"^\s*//.*$")
- templates = {}
- name = None
- template = []
- def finish_template(template):
- return string.Template("\n".join(template))
+ Returns:
+ A namedtuple object whose attributes are the templates from text.
+ """
+ delimiter_re = re.compile(r"^\W*\*\* ([A-Za-z][A-Za-z0-9_]*) \*\*\W*$")
+ comment_re = re.compile(r"^\s*//.*$")
+ templates = {}
+ name = None
+ template = []
- for line in text.splitlines():
- if delimiter_re.match(line):
- if name:
+ def finish_template(template):
+ return string.Template("\n".join(template))
+
+ for line in text.splitlines():
+ if delimiter_re.match(line):
+ if name:
+ templates[name] = finish_template(template)
+ name = delimiter_re.match(line).group(1)
+ template = []
+ else:
+ if not comment_re.match(line):
+ template.append(line)
+ if name:
templates[name] = finish_template(template)
- name = delimiter_re.match(line).group(1)
- template = []
- else:
- if not comment_re.match(line):
- template.append(line)
- if name:
- templates[name] = finish_template(template)
- return collections.namedtuple("Templates",
- list(templates.keys()))(**templates)
+ return collections.namedtuple("Templates", list(templates.keys()))(**templates)
diff --git a/compiler/back_end/util/code_template_test.py b/compiler/back_end/util/code_template_test.py
index c133096..e4354fe 100644
--- a/compiler/back_end/util/code_template_test.py
+++ b/compiler/back_end/util/code_template_test.py
@@ -18,82 +18,88 @@
import unittest
from compiler.back_end.util import code_template
+
def _format_template_str(template: str, **kwargs) -> str:
- return code_template.format_template(string.Template(template), **kwargs)
+ return code_template.format_template(string.Template(template), **kwargs)
+
class FormatTest(unittest.TestCase):
- """Tests for code_template.format."""
+ """Tests for code_template.format."""
- def test_no_replacement_fields(self):
- self.assertEqual("foo", _format_template_str("foo"))
- self.assertEqual("{foo}", _format_template_str("{foo}"))
- self.assertEqual("${foo}", _format_template_str("$${foo}"))
+ def test_no_replacement_fields(self):
+ self.assertEqual("foo", _format_template_str("foo"))
+ self.assertEqual("{foo}", _format_template_str("{foo}"))
+ self.assertEqual("${foo}", _format_template_str("$${foo}"))
- def test_one_replacement_field(self):
- self.assertEqual("foo", _format_template_str("${bar}", bar="foo"))
- self.assertEqual("bazfoo",
- _format_template_str("baz${bar}", bar="foo"))
- self.assertEqual("foobaz",
- _format_template_str("${bar}baz", bar="foo"))
- self.assertEqual("bazfooqux",
- _format_template_str("baz${bar}qux", bar="foo"))
+ def test_one_replacement_field(self):
+ self.assertEqual("foo", _format_template_str("${bar}", bar="foo"))
+ self.assertEqual("bazfoo", _format_template_str("baz${bar}", bar="foo"))
+ self.assertEqual("foobaz", _format_template_str("${bar}baz", bar="foo"))
+ self.assertEqual("bazfooqux", _format_template_str("baz${bar}qux", bar="foo"))
- def test_one_replacement_field_with_formatting(self):
- # Basic string.Templates don't support formatting values.
- self.assertRaises(ValueError,
- _format_template_str, "${bar:.6f}", bar=1)
+ def test_one_replacement_field_with_formatting(self):
+ # Basic string.Templates don't support formatting values.
+ self.assertRaises(ValueError, _format_template_str, "${bar:.6f}", bar=1)
- def test_one_replacement_field_value_missing(self):
- self.assertRaises(KeyError, _format_template_str, "${bar}")
+ def test_one_replacement_field_value_missing(self):
+ self.assertRaises(KeyError, _format_template_str, "${bar}")
- def test_multiple_replacement_fields(self):
- self.assertEqual(" aaa bbb ",
- _format_template_str(" ${bar} ${baz} ",
- bar="aaa",
- baz="bbb"))
+ def test_multiple_replacement_fields(self):
+ self.assertEqual(
+ " aaa bbb ",
+ _format_template_str(" ${bar} ${baz} ", bar="aaa", baz="bbb"),
+ )
class ParseTemplatesTest(unittest.TestCase):
- """Tests for code_template.parse_templates."""
+ """Tests for code_template.parse_templates."""
- def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name
- """Compares the results of a parse_templates"""
- # Extract the name and template from the result tuple
- actual = {
- k: v.template for k, v in actual._asdict().items()
- }
- self.assertEqual(expected, actual)
+ def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name
+ """Compares the results of a parse_templates"""
+ # Extract the name and template from the result tuple
+ actual = {k: v.template for k, v in actual._asdict().items()}
+ self.assertEqual(expected, actual)
- def test_handles_no_template_case(self):
- self.assertTemplatesEqual({}, code_template.parse_templates(""))
- self.assertTemplatesEqual({}, code_template.parse_templates(
- "this is not a template"))
+ def test_handles_no_template_case(self):
+ self.assertTemplatesEqual({}, code_template.parse_templates(""))
+ self.assertTemplatesEqual(
+ {}, code_template.parse_templates("this is not a template")
+ )
- def test_handles_one_template_at_start(self):
- self.assertTemplatesEqual({"foo": "bar"},
- code_template.parse_templates("** foo **\nbar"))
+ def test_handles_one_template_at_start(self):
+ self.assertTemplatesEqual(
+ {"foo": "bar"}, code_template.parse_templates("** foo **\nbar")
+ )
- def test_handles_one_template_after_start(self):
- self.assertTemplatesEqual(
- {"foo": "bar"},
- code_template.parse_templates("text\n** foo **\nbar"))
+ def test_handles_one_template_after_start(self):
+ self.assertTemplatesEqual(
+ {"foo": "bar"}, code_template.parse_templates("text\n** foo **\nbar")
+ )
- def test_handles_delimiter_with_other_text(self):
- self.assertTemplatesEqual(
- {"foo": "bar"},
- code_template.parse_templates("text\n// ** foo ** ////\nbar"))
- self.assertTemplatesEqual(
- {"foo": "bar"},
- code_template.parse_templates("text\n# ** foo ** #####\nbar"))
+ def test_handles_delimiter_with_other_text(self):
+ self.assertTemplatesEqual(
+ {"foo": "bar"},
+ code_template.parse_templates("text\n// ** foo ** ////\nbar"),
+ )
+ self.assertTemplatesEqual(
+ {"foo": "bar"},
+ code_template.parse_templates("text\n# ** foo ** #####\nbar"),
+ )
- def test_handles_multiple_delimiters(self):
- self.assertTemplatesEqual({"foo": "bar",
- "baz": "qux"}, code_template.parse_templates(
- "** foo **\nbar\n** baz **\nqux"))
+ def test_handles_multiple_delimiters(self):
+ self.assertTemplatesEqual(
+ {"foo": "bar", "baz": "qux"},
+ code_template.parse_templates("** foo **\nbar\n** baz **\nqux"),
+ )
- def test_returns_object_with_attributes(self):
- self.assertEqual("bar", code_template.parse_templates(
- "** foo **\nbar\n** baz **\nqux").foo.template)
+ def test_returns_object_with_attributes(self):
+ self.assertEqual(
+ "bar",
+ code_template.parse_templates(
+ "** foo **\nbar\n** baz **\nqux"
+ ).foo.template,
+ )
+
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/__init__.py b/compiler/front_end/__init__.py
index 2c31d84..086a24e 100644
--- a/compiler/front_end/__init__.py
+++ b/compiler/front_end/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py
index d8c637c..9e7fec2 100644
--- a/compiler/front_end/attribute_checker.py
+++ b/compiler/front_end/attribute_checker.py
@@ -38,22 +38,29 @@
# Attribute type checkers
_VALID_BYTE_ORDER = attribute_util.string_from_list(
- {"BigEndian", "LittleEndian", "Null"})
+ {"BigEndian", "LittleEndian", "Null"}
+)
_VALID_TEXT_OUTPUT = attribute_util.string_from_list({"Emit", "Skip"})
def _valid_back_ends(attr, module_source_file):
- if not re.match(
- r"^(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*$",
- attr.value.string_constant.text):
- return [[error.error(
- module_source_file,
- attr.value.source_location,
- "Attribute '{name}' must be a comma-delimited list of back end "
- "specifiers (like \"cpp, proto\")), not \"{value}\".".format(
- name=attr.name.text,
- value=attr.value.string_constant.text))]]
- return []
+ if not re.match(
+ r"^(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*$",
+ attr.value.string_constant.text,
+ ):
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ "Attribute '{name}' must be a comma-delimited list of back end "
+ 'specifiers (like "cpp, proto")), not "{value}".'.format(
+ name=attr.name.text, value=attr.value.string_constant.text
+ ),
+ )
+ ]
+ ]
+ return []
# Attributes must be the same type no matter where they occur.
@@ -105,401 +112,539 @@
def _construct_integer_attribute(name, value, source_location):
- """Constructs an integer Attribute with the given name and value."""
- attr_value = ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value=str(value),
- source_location=source_location),
- type=ir_data.ExpressionType(
- integer=ir_data.IntegerType(modular_value=str(value),
- modulus="infinity",
- minimum_value=str(value),
- maximum_value=str(value))),
- source_location=source_location),
- source_location=source_location)
- return ir_data.Attribute(name=ir_data.Word(text=name,
- source_location=source_location),
- value=attr_value,
- source_location=source_location)
+ """Constructs an integer Attribute with the given name and value."""
+ attr_value = ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(
+ value=str(value), source_location=source_location
+ ),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value=str(value),
+ modulus="infinity",
+ minimum_value=str(value),
+ maximum_value=str(value),
+ )
+ ),
+ source_location=source_location,
+ ),
+ source_location=source_location,
+ )
+ return ir_data.Attribute(
+ name=ir_data.Word(text=name, source_location=source_location),
+ value=attr_value,
+ source_location=source_location,
+ )
def _construct_boolean_attribute(name, value, source_location):
- """Constructs a boolean Attribute with the given name and value."""
- attr_value = ir_data.AttributeValue(
- expression=ir_data.Expression(
- boolean_constant=ir_data.BooleanConstant(
- value=value, source_location=source_location),
- type=ir_data.ExpressionType(boolean=ir_data.BooleanType(value=value)),
- source_location=source_location),
- source_location=source_location)
- return ir_data.Attribute(name=ir_data.Word(text=name,
- source_location=source_location),
- value=attr_value,
- source_location=source_location)
+ """Constructs a boolean Attribute with the given name and value."""
+ attr_value = ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ boolean_constant=ir_data.BooleanConstant(
+ value=value, source_location=source_location
+ ),
+ type=ir_data.ExpressionType(boolean=ir_data.BooleanType(value=value)),
+ source_location=source_location,
+ ),
+ source_location=source_location,
+ )
+ return ir_data.Attribute(
+ name=ir_data.Word(text=name, source_location=source_location),
+ value=attr_value,
+ source_location=source_location,
+ )
def _construct_string_attribute(name, value, source_location):
- """Constructs a string Attribute with the given name and value."""
- attr_value = ir_data.AttributeValue(
- string_constant=ir_data.String(text=value,
- source_location=source_location),
- source_location=source_location)
- return ir_data.Attribute(name=ir_data.Word(text=name,
- source_location=source_location),
- value=attr_value,
- source_location=source_location)
+ """Constructs a string Attribute with the given name and value."""
+ attr_value = ir_data.AttributeValue(
+ string_constant=ir_data.String(text=value, source_location=source_location),
+ source_location=source_location,
+ )
+ return ir_data.Attribute(
+ name=ir_data.Word(text=name, source_location=source_location),
+ value=attr_value,
+ source_location=source_location,
+ )
def _fixed_size_of_struct_or_bits(struct, unit_size):
- """Returns size of struct in bits or None, if struct is not fixed size."""
- size = 0
- for field in struct.field:
- if not field.HasField("location"):
- # Virtual fields do not contribute to the physical size of the struct.
- continue
- field_start = ir_util.constant_value(field.location.start)
- field_size = ir_util.constant_value(field.location.size)
- if field_start is None or field_size is None:
- # Technically, start + size could be constant even if start and size are
- # not; e.g. if start == x and size == 10 - x, but we don't handle that
- # here.
- return None
- # TODO(bolms): knows_own_size
- # TODO(bolms): compute min/max sizes for variable-sized arrays.
- field_end = field_start + field_size
- if field_end >= size:
- size = field_end
- return size * unit_size
+ """Returns size of struct in bits or None, if struct is not fixed size."""
+ size = 0
+ for field in struct.field:
+ if not field.HasField("location"):
+ # Virtual fields do not contribute to the physical size of the struct.
+ continue
+ field_start = ir_util.constant_value(field.location.start)
+ field_size = ir_util.constant_value(field.location.size)
+ if field_start is None or field_size is None:
+ # Technically, start + size could be constant even if start and size are
+ # not; e.g. if start == x and size == 10 - x, but we don't handle that
+ # here.
+ return None
+ # TODO(bolms): knows_own_size
+ # TODO(bolms): compute min/max sizes for variable-sized arrays.
+ field_end = field_start + field_size
+ if field_end >= size:
+ size = field_end
+ return size * unit_size
-def _verify_size_attributes_on_structure(struct, type_definition,
- source_file_name, errors):
- """Verifies size attributes on a struct or bits."""
- fixed_size = _fixed_size_of_struct_or_bits(struct,
- type_definition.addressable_unit)
- fixed_size_attr = ir_util.get_attribute(type_definition.attribute,
- attributes.FIXED_SIZE)
- if not fixed_size_attr:
- return
- if fixed_size is None:
- errors.append([error.error(
- source_file_name, fixed_size_attr.source_location,
- "Struct is marked as fixed size, but contains variable-location "
- "fields.")])
- elif ir_util.constant_value(fixed_size_attr.expression) != fixed_size:
- errors.append([error.error(
- source_file_name, fixed_size_attr.source_location,
- "Struct is {} bits, but is marked as {} bits.".format(
- fixed_size, ir_util.constant_value(fixed_size_attr.expression)))])
+def _verify_size_attributes_on_structure(
+ struct, type_definition, source_file_name, errors
+):
+ """Verifies size attributes on a struct or bits."""
+ fixed_size = _fixed_size_of_struct_or_bits(struct, type_definition.addressable_unit)
+ fixed_size_attr = ir_util.get_attribute(
+ type_definition.attribute, attributes.FIXED_SIZE
+ )
+ if not fixed_size_attr:
+ return
+ if fixed_size is None:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ fixed_size_attr.source_location,
+ "Struct is marked as fixed size, but contains variable-location "
+ "fields.",
+ )
+ ]
+ )
+ elif ir_util.constant_value(fixed_size_attr.expression) != fixed_size:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ fixed_size_attr.source_location,
+ "Struct is {} bits, but is marked as {} bits.".format(
+ fixed_size, ir_util.constant_value(fixed_size_attr.expression)
+ ),
+ )
+ ]
+ )
# TODO(bolms): remove [fixed_size]; it is superseded by $size_in_{bits,bytes}
def _add_missing_size_attributes_on_structure(struct, type_definition):
- """Adds missing size attributes on a struct."""
- fixed_size = _fixed_size_of_struct_or_bits(struct,
- type_definition.addressable_unit)
- if fixed_size is None:
- return
- fixed_size_attr = ir_util.get_attribute(type_definition.attribute,
- attributes.FIXED_SIZE)
- if not fixed_size_attr:
- # TODO(bolms): Use the offset and length of the last field as the
- # source_location of the fixed_size attribute?
- type_definition.attribute.extend([
- _construct_integer_attribute(attributes.FIXED_SIZE, fixed_size,
- type_definition.source_location)])
+ """Adds missing size attributes on a struct."""
+ fixed_size = _fixed_size_of_struct_or_bits(struct, type_definition.addressable_unit)
+ if fixed_size is None:
+ return
+ fixed_size_attr = ir_util.get_attribute(
+ type_definition.attribute, attributes.FIXED_SIZE
+ )
+ if not fixed_size_attr:
+ # TODO(bolms): Use the offset and length of the last field as the
+ # source_location of the fixed_size attribute?
+ type_definition.attribute.extend(
+ [
+ _construct_integer_attribute(
+ attributes.FIXED_SIZE, fixed_size, type_definition.source_location
+ )
+ ]
+ )
def _field_needs_byte_order(field, type_definition, ir):
- """Returns true if the given field needs a byte_order attribute."""
- if ir_util.field_is_virtual(field):
- # Virtual fields have no physical type, and thus do not need a byte order.
- return False
- field_type = ir_util.find_object(
- ir_util.get_base_type(field.type).atomic_type.reference.canonical_name,
- ir)
- assert field_type is not None
- assert field_type.addressable_unit != ir_data.AddressableUnit.NONE
- return field_type.addressable_unit != type_definition.addressable_unit
+ """Returns true if the given field needs a byte_order attribute."""
+ if ir_util.field_is_virtual(field):
+ # Virtual fields have no physical type, and thus do not need a byte order.
+ return False
+ field_type = ir_util.find_object(
+ ir_util.get_base_type(field.type).atomic_type.reference.canonical_name, ir
+ )
+ assert field_type is not None
+ assert field_type.addressable_unit != ir_data.AddressableUnit.NONE
+ return field_type.addressable_unit != type_definition.addressable_unit
def _field_may_have_null_byte_order(field, type_definition, ir):
- """Returns true if "Null" is a valid byte order for the given field."""
- # If the field is one unit in length, then byte order does not matter.
- if (ir_util.is_constant(field.location.size) and
- ir_util.constant_value(field.location.size) == 1):
- return True
- unit = type_definition.addressable_unit
- # Otherwise, if the field's type is either a one-unit-sized type or an array
- # of a one-unit-sized type, then byte order does not matter.
- if (ir_util.fixed_size_of_type_in_bits(ir_util.get_base_type(field.type), ir)
- == unit):
- return True
- # In all other cases, byte order does matter.
- return False
+ """Returns true if "Null" is a valid byte order for the given field."""
+ # If the field is one unit in length, then byte order does not matter.
+ if (
+ ir_util.is_constant(field.location.size)
+ and ir_util.constant_value(field.location.size) == 1
+ ):
+ return True
+ unit = type_definition.addressable_unit
+ # Otherwise, if the field's type is either a one-unit-sized type or an array
+ # of a one-unit-sized type, then byte order does not matter.
+ if (
+ ir_util.fixed_size_of_type_in_bits(ir_util.get_base_type(field.type), ir)
+ == unit
+ ):
+ return True
+ # In all other cases, byte order does matter.
+ return False
-def _add_missing_byte_order_attribute_on_field(field, type_definition, ir,
- defaults):
- """Adds missing byte_order attributes to fields that need them."""
- if _field_needs_byte_order(field, type_definition, ir):
- byte_order_attr = ir_util.get_attribute(field.attribute,
- attributes.BYTE_ORDER)
- if byte_order_attr is None:
- if attributes.BYTE_ORDER in defaults:
- field.attribute.extend([defaults[attributes.BYTE_ORDER]])
- elif _field_may_have_null_byte_order(field, type_definition, ir):
- field.attribute.extend(
- [_construct_string_attribute(attributes.BYTE_ORDER, "Null",
- field.source_location)])
+def _add_missing_byte_order_attribute_on_field(field, type_definition, ir, defaults):
+ """Adds missing byte_order attributes to fields that need them."""
+ if _field_needs_byte_order(field, type_definition, ir):
+ byte_order_attr = ir_util.get_attribute(field.attribute, attributes.BYTE_ORDER)
+ if byte_order_attr is None:
+ if attributes.BYTE_ORDER in defaults:
+ field.attribute.extend([defaults[attributes.BYTE_ORDER]])
+ elif _field_may_have_null_byte_order(field, type_definition, ir):
+ field.attribute.extend(
+ [
+ _construct_string_attribute(
+ attributes.BYTE_ORDER, "Null", field.source_location
+ )
+ ]
+ )
def _add_missing_back_ends_to_module(module):
- """Sets the expected_back_ends attribute for a module, if not already set."""
- back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS)
- if back_ends_attr is None:
- module.attribute.extend(
- [_construct_string_attribute(attributes.BACK_ENDS, _DEFAULT_BACK_ENDS,
- module.source_location)])
+ """Sets the expected_back_ends attribute for a module, if not already set."""
+ back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS)
+ if back_ends_attr is None:
+ module.attribute.extend(
+ [
+ _construct_string_attribute(
+ attributes.BACK_ENDS, _DEFAULT_BACK_ENDS, module.source_location
+ )
+ ]
+ )
def _gather_expected_back_ends(module):
- """Captures the expected_back_ends attribute for `module`."""
- back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS)
- back_ends_str = back_ends_attr.string_constant.text
- return {
- "expected_back_ends": {x.strip() for x in back_ends_str.split(",")} | {""}
- }
+ """Captures the expected_back_ends attribute for `module`."""
+ back_ends_attr = ir_util.get_attribute(module.attribute, attributes.BACK_ENDS)
+ back_ends_str = back_ends_attr.string_constant.text
+ return {"expected_back_ends": {x.strip() for x in back_ends_str.split(",")} | {""}}
def _add_addressable_unit_to_external(external, type_definition):
- """Sets the addressable_unit field for an external TypeDefinition."""
- # Strictly speaking, addressable_unit isn't an "attribute," but it's close
- # enough that it makes sense to handle it with attributes.
- del external # Unused.
- size = ir_util.get_integer_attribute(type_definition.attribute,
- attributes.ADDRESSABLE_UNIT_SIZE)
- if size == 1:
- type_definition.addressable_unit = ir_data.AddressableUnit.BIT
- elif size == 8:
- type_definition.addressable_unit = ir_data.AddressableUnit.BYTE
- # If the addressable_unit_size is not in (1, 8), it will be caught by
- # _verify_addressable_unit_attribute_on_external, below.
+ """Sets the addressable_unit field for an external TypeDefinition."""
+ # Strictly speaking, addressable_unit isn't an "attribute," but it's close
+ # enough that it makes sense to handle it with attributes.
+ del external # Unused.
+ size = ir_util.get_integer_attribute(
+ type_definition.attribute, attributes.ADDRESSABLE_UNIT_SIZE
+ )
+ if size == 1:
+ type_definition.addressable_unit = ir_data.AddressableUnit.BIT
+ elif size == 8:
+ type_definition.addressable_unit = ir_data.AddressableUnit.BYTE
+ # If the addressable_unit_size is not in (1, 8), it will be caught by
+ # _verify_addressable_unit_attribute_on_external, below.
def _add_missing_width_and_sign_attributes_on_enum(enum, type_definition):
- """Sets the maximum_bits and is_signed attributes for an enum, if needed."""
- max_bits_attr = ir_util.get_integer_attribute(type_definition.attribute,
- attributes.ENUM_MAXIMUM_BITS)
- if max_bits_attr is None:
- type_definition.attribute.extend([
- _construct_integer_attribute(attributes.ENUM_MAXIMUM_BITS,
- _DEFAULT_ENUM_MAXIMUM_BITS,
- type_definition.source_location)])
- signed_attr = ir_util.get_boolean_attribute(type_definition.attribute,
- attributes.IS_SIGNED)
- if signed_attr is None:
- for value in enum.value:
- numeric_value = ir_util.constant_value(value.value)
- if numeric_value < 0:
- is_signed = True
- break
- else:
- is_signed = False
- type_definition.attribute.extend([
- _construct_boolean_attribute(attributes.IS_SIGNED, is_signed,
- type_definition.source_location)])
+ """Sets the maximum_bits and is_signed attributes for an enum, if needed."""
+ max_bits_attr = ir_util.get_integer_attribute(
+ type_definition.attribute, attributes.ENUM_MAXIMUM_BITS
+ )
+ if max_bits_attr is None:
+ type_definition.attribute.extend(
+ [
+ _construct_integer_attribute(
+ attributes.ENUM_MAXIMUM_BITS,
+ _DEFAULT_ENUM_MAXIMUM_BITS,
+ type_definition.source_location,
+ )
+ ]
+ )
+ signed_attr = ir_util.get_boolean_attribute(
+ type_definition.attribute, attributes.IS_SIGNED
+ )
+ if signed_attr is None:
+ for value in enum.value:
+ numeric_value = ir_util.constant_value(value.value)
+ if numeric_value < 0:
+ is_signed = True
+ break
+ else:
+ is_signed = False
+ type_definition.attribute.extend(
+ [
+ _construct_boolean_attribute(
+ attributes.IS_SIGNED, is_signed, type_definition.source_location
+ )
+ ]
+ )
-def _verify_byte_order_attribute_on_field(field, type_definition,
- source_file_name, ir, errors):
- """Verifies the byte_order attribute on the given field."""
- byte_order_attr = ir_util.get_attribute(field.attribute,
- attributes.BYTE_ORDER)
- field_needs_byte_order = _field_needs_byte_order(field, type_definition, ir)
- if byte_order_attr and not field_needs_byte_order:
- errors.append([error.error(
- source_file_name, byte_order_attr.source_location,
- "Attribute 'byte_order' not allowed on field which is not byte order "
- "dependent.")])
- if not byte_order_attr and field_needs_byte_order:
- errors.append([error.error(
- source_file_name, field.source_location,
- "Attribute 'byte_order' required on field which is byte order "
- "dependent.")])
- if (byte_order_attr and byte_order_attr.string_constant.text == "Null" and
- not _field_may_have_null_byte_order(field, type_definition, ir)):
- errors.append([error.error(
- source_file_name, byte_order_attr.source_location,
- "Attribute 'byte_order' may only be 'Null' for one-byte fields.")])
+def _verify_byte_order_attribute_on_field(
+ field, type_definition, source_file_name, ir, errors
+):
+ """Verifies the byte_order attribute on the given field."""
+ byte_order_attr = ir_util.get_attribute(field.attribute, attributes.BYTE_ORDER)
+ field_needs_byte_order = _field_needs_byte_order(field, type_definition, ir)
+ if byte_order_attr and not field_needs_byte_order:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ byte_order_attr.source_location,
+ "Attribute 'byte_order' not allowed on field which is not byte order "
+ "dependent.",
+ )
+ ]
+ )
+ if not byte_order_attr and field_needs_byte_order:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ field.source_location,
+ "Attribute 'byte_order' required on field which is byte order "
+ "dependent.",
+ )
+ ]
+ )
+ if (
+ byte_order_attr
+ and byte_order_attr.string_constant.text == "Null"
+ and not _field_may_have_null_byte_order(field, type_definition, ir)
+ ):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ byte_order_attr.source_location,
+ "Attribute 'byte_order' may only be 'Null' for one-byte fields.",
+ )
+ ]
+ )
def _verify_requires_attribute_on_field(field, source_file_name, ir, errors):
- """Verifies that [requires] is valid on the given field."""
- requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
- if not requires_attr:
- return
- if ir_util.field_is_virtual(field):
- field_expression_type = field.read_transform.type
- else:
- if not field.type.HasField("atomic_type"):
- errors.append([
- error.error(source_file_name, requires_attr.source_location,
- "Attribute 'requires' is only allowed on integer, "
- "enumeration, or boolean fields, not arrays."),
- error.note(source_file_name, field.type.source_location,
- "Field type."),
- ])
- return
- field_type = ir_util.find_object(field.type.atomic_type.reference, ir)
- assert field_type, "Field type should be non-None after name resolution."
- field_expression_type = (
- type_check.unbounded_expression_type_for_physical_type(field_type))
- if field_expression_type.WhichOneof("type") not in (
- "integer", "enumeration", "boolean"):
- errors.append([error.error(
- source_file_name, requires_attr.source_location,
- "Attribute 'requires' is only allowed on integer, enumeration, or "
- "boolean fields.")])
+ """Verifies that [requires] is valid on the given field."""
+ requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
+ if not requires_attr:
+ return
+ if ir_util.field_is_virtual(field):
+ field_expression_type = field.read_transform.type
+ else:
+ if not field.type.HasField("atomic_type"):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ requires_attr.source_location,
+ "Attribute 'requires' is only allowed on integer, "
+ "enumeration, or boolean fields, not arrays.",
+ ),
+ error.note(
+ source_file_name, field.type.source_location, "Field type."
+ ),
+ ]
+ )
+ return
+ field_type = ir_util.find_object(field.type.atomic_type.reference, ir)
+ assert field_type, "Field type should be non-None after name resolution."
+ field_expression_type = type_check.unbounded_expression_type_for_physical_type(
+ field_type
+ )
+ if field_expression_type.WhichOneof("type") not in (
+ "integer",
+ "enumeration",
+ "boolean",
+ ):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ requires_attr.source_location,
+ "Attribute 'requires' is only allowed on integer, enumeration, or "
+ "boolean fields.",
+ )
+ ]
+ )
-def _verify_addressable_unit_attribute_on_external(external, type_definition,
- source_file_name, errors):
- """Verifies the addressable_unit_size attribute on an external."""
- del external # Unused.
- addressable_unit_size_attr = ir_util.get_integer_attribute(
- type_definition.attribute, attributes.ADDRESSABLE_UNIT_SIZE)
- if addressable_unit_size_attr is None:
- errors.append([error.error(
- source_file_name, type_definition.source_location,
- "Expected '{}' attribute for external type.".format(
- attributes.ADDRESSABLE_UNIT_SIZE))])
- elif addressable_unit_size_attr not in (1, 8):
- errors.append([
- error.error(source_file_name, type_definition.source_location,
+def _verify_addressable_unit_attribute_on_external(
+ external, type_definition, source_file_name, errors
+):
+ """Verifies the addressable_unit_size attribute on an external."""
+ del external # Unused.
+ addressable_unit_size_attr = ir_util.get_integer_attribute(
+ type_definition.attribute, attributes.ADDRESSABLE_UNIT_SIZE
+ )
+ if addressable_unit_size_attr is None:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_definition.source_location,
+ "Expected '{}' attribute for external type.".format(
+ attributes.ADDRESSABLE_UNIT_SIZE
+ ),
+ )
+ ]
+ )
+ elif addressable_unit_size_attr not in (1, 8):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_definition.source_location,
"Only values '1' (bit) and '8' (byte) are allowed for the "
- "'{}' attribute".format(attributes.ADDRESSABLE_UNIT_SIZE))
- ])
+ "'{}' attribute".format(attributes.ADDRESSABLE_UNIT_SIZE),
+ )
+ ]
+ )
-def _verify_width_attribute_on_enum(enum, type_definition, source_file_name,
- errors):
- """Verifies the maximum_bits attribute for an enum TypeDefinition."""
- max_bits_value = ir_util.get_integer_attribute(type_definition.attribute,
- attributes.ENUM_MAXIMUM_BITS)
- # The attribute should already have been defaulted, if not originally present.
- assert max_bits_value is not None, "maximum_bits not set"
- if max_bits_value > 64 or max_bits_value < 1:
- max_bits_attr = ir_util.get_attribute(type_definition.attribute,
- attributes.ENUM_MAXIMUM_BITS)
- errors.append([
- error.error(source_file_name, max_bits_attr.source_location,
- "'maximum_bits' on an 'enum' must be between 1 and 64.")
- ])
+def _verify_width_attribute_on_enum(enum, type_definition, source_file_name, errors):
+ """Verifies the maximum_bits attribute for an enum TypeDefinition."""
+ max_bits_value = ir_util.get_integer_attribute(
+ type_definition.attribute, attributes.ENUM_MAXIMUM_BITS
+ )
+ # The attribute should already have been defaulted, if not originally present.
+ assert max_bits_value is not None, "maximum_bits not set"
+ if max_bits_value > 64 or max_bits_value < 1:
+ max_bits_attr = ir_util.get_attribute(
+ type_definition.attribute, attributes.ENUM_MAXIMUM_BITS
+ )
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ max_bits_attr.source_location,
+ "'maximum_bits' on an 'enum' must be between 1 and 64.",
+ )
+ ]
+ )
def _add_missing_attributes_on_ir(ir):
- """Adds missing attributes in a complete IR."""
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Module], _add_missing_back_ends_to_module)
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.External], _add_addressable_unit_to_external)
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Enum], _add_missing_width_and_sign_attributes_on_enum)
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure], _add_missing_size_attributes_on_structure,
- incidental_actions={
- ir_data.Module: attribute_util.gather_default_attributes,
- ir_data.TypeDefinition: attribute_util.gather_default_attributes,
- ir_data.Field: attribute_util.gather_default_attributes,
- },
- parameters={"defaults": {}})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Field], _add_missing_byte_order_attribute_on_field,
- incidental_actions={
- ir_data.Module: attribute_util.gather_default_attributes,
- ir_data.TypeDefinition: attribute_util.gather_default_attributes,
- ir_data.Field: attribute_util.gather_default_attributes,
- },
- parameters={"defaults": {}})
- return []
+ """Adds missing attributes in a complete IR."""
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Module], _add_missing_back_ends_to_module
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.External], _add_addressable_unit_to_external
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Enum], _add_missing_width_and_sign_attributes_on_enum
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Structure],
+ _add_missing_size_attributes_on_structure,
+ incidental_actions={
+ ir_data.Module: attribute_util.gather_default_attributes,
+ ir_data.TypeDefinition: attribute_util.gather_default_attributes,
+ ir_data.Field: attribute_util.gather_default_attributes,
+ },
+ parameters={"defaults": {}},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Field],
+ _add_missing_byte_order_attribute_on_field,
+ incidental_actions={
+ ir_data.Module: attribute_util.gather_default_attributes,
+ ir_data.TypeDefinition: attribute_util.gather_default_attributes,
+ ir_data.Field: attribute_util.gather_default_attributes,
+ },
+ parameters={"defaults": {}},
+ )
+ return []
-def _verify_field_attributes(field, type_definition, source_file_name, ir,
- errors):
- _verify_byte_order_attribute_on_field(field, type_definition,
- source_file_name, ir, errors)
- _verify_requires_attribute_on_field(field, source_file_name, ir, errors)
+def _verify_field_attributes(field, type_definition, source_file_name, ir, errors):
+ _verify_byte_order_attribute_on_field(
+ field, type_definition, source_file_name, ir, errors
+ )
+ _verify_requires_attribute_on_field(field, source_file_name, ir, errors)
-def _verify_back_end_attributes(attribute, expected_back_ends, source_file_name,
- ir, errors):
- back_end_text = ir_data_utils.reader(attribute).back_end.text
- if back_end_text not in expected_back_ends:
- expected_back_ends_for_error = expected_back_ends - {""}
- errors.append([error.error(
- source_file_name, attribute.back_end.source_location,
- "Back end specifier '{back_end}' does not match any expected back end "
- "specifier for this file: '{expected_back_ends}'. Add or update the "
- "'[expected_back_ends: \"{new_expected_back_ends}\"]' attribute at the "
- "file level if this back end specifier is intentional.".format(
- back_end=attribute.back_end.text,
- expected_back_ends="', '".join(
- sorted(expected_back_ends_for_error)),
- new_expected_back_ends=", ".join(
- sorted(expected_back_ends_for_error | {back_end_text})),
- ))])
+def _verify_back_end_attributes(
+ attribute, expected_back_ends, source_file_name, ir, errors
+):
+ back_end_text = ir_data_utils.reader(attribute).back_end.text
+ if back_end_text not in expected_back_ends:
+ expected_back_ends_for_error = expected_back_ends - {""}
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ attribute.back_end.source_location,
+ "Back end specifier '{back_end}' does not match any expected back end "
+ "specifier for this file: '{expected_back_ends}'. Add or update the "
+ "'[expected_back_ends: \"{new_expected_back_ends}\"]' attribute at the "
+ "file level if this back end specifier is intentional.".format(
+ back_end=attribute.back_end.text,
+ expected_back_ends="', '".join(
+ sorted(expected_back_ends_for_error)
+ ),
+ new_expected_back_ends=", ".join(
+ sorted(expected_back_ends_for_error | {back_end_text})
+ ),
+ ),
+ )
+ ]
+ )
def _verify_attributes_on_ir(ir):
- """Verifies attributes in a complete IR."""
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Attribute], _verify_back_end_attributes,
- incidental_actions={
- ir_data.Module: _gather_expected_back_ends,
- },
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure], _verify_size_attributes_on_structure,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Enum], _verify_width_attribute_on_enum,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.External], _verify_addressable_unit_attribute_on_external,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Field], _verify_field_attributes,
- parameters={"errors": errors})
- return errors
+ """Verifies attributes in a complete IR."""
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Attribute],
+ _verify_back_end_attributes,
+ incidental_actions={
+ ir_data.Module: _gather_expected_back_ends,
+ },
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Structure],
+ _verify_size_attributes_on_structure,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Enum],
+ _verify_width_attribute_on_enum,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.External],
+ _verify_addressable_unit_attribute_on_external,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Field], _verify_field_attributes, parameters={"errors": errors}
+ )
+ return errors
def normalize_and_verify(ir):
- """Performs various normalizations and verifications on ir.
+ """Performs various normalizations and verifications on ir.
- Checks for duplicate attributes.
+ Checks for duplicate attributes.
- Adds fixed_size_in_bits and addressable_unit_size attributes to types when
- they are missing, and checks their correctness when they are not missing.
+ Adds fixed_size_in_bits and addressable_unit_size attributes to types when
+ they are missing, and checks their correctness when they are not missing.
- Arguments:
- ir: The IR object to normalize.
+ Arguments:
+ ir: The IR object to normalize.
- Returns:
- A list of validation errors, or an empty list if no errors were encountered.
- """
- errors = attribute_util.check_attributes_in_ir(
- ir,
- types=_ATTRIBUTE_TYPES,
- module_attributes=_MODULE_ATTRIBUTES,
- struct_attributes=_STRUCT_ATTRIBUTES,
- bits_attributes=_BITS_ATTRIBUTES,
- enum_attributes=_ENUM_ATTRIBUTES,
- external_attributes=_EXTERNAL_ATTRIBUTES,
- structure_virtual_field_attributes=_STRUCT_VIRTUAL_FIELD_ATTRIBUTES,
- structure_physical_field_attributes=_STRUCT_PHYSICAL_FIELD_ATTRIBUTES)
- if errors:
- return errors
- _add_missing_attributes_on_ir(ir)
- return _verify_attributes_on_ir(ir)
+ Returns:
+ A list of validation errors, or an empty list if no errors were encountered.
+ """
+ errors = attribute_util.check_attributes_in_ir(
+ ir,
+ types=_ATTRIBUTE_TYPES,
+ module_attributes=_MODULE_ATTRIBUTES,
+ struct_attributes=_STRUCT_ATTRIBUTES,
+ bits_attributes=_BITS_ATTRIBUTES,
+ enum_attributes=_ENUM_ATTRIBUTES,
+ external_attributes=_EXTERNAL_ATTRIBUTES,
+ structure_virtual_field_attributes=_STRUCT_VIRTUAL_FIELD_ATTRIBUTES,
+ structure_physical_field_attributes=_STRUCT_PHYSICAL_FIELD_ATTRIBUTES,
+ )
+ if errors:
+ return errors
+ _add_missing_attributes_on_ir(ir)
+ return _verify_attributes_on_ir(ir)
diff --git a/compiler/front_end/attribute_checker_test.py b/compiler/front_end/attribute_checker_test.py
index 4e7d8c7..4325ca4 100644
--- a/compiler/front_end/attribute_checker_test.py
+++ b/compiler/front_end/attribute_checker_test.py
@@ -31,667 +31,985 @@
def _make_ir_from_emb(emb_text, name="m.emb"):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- name,
- test_util.dict_file_reader({name: emb_text}),
- stop_before_step="normalize_and_verify")
- assert not errors
- return ir
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ name,
+ test_util.dict_file_reader({name: emb_text}),
+ stop_before_step="normalize_and_verify",
+ )
+ assert not errors
+ return ir
class NormalizeIrTest(unittest.TestCase):
- def test_rejects_may_be_used_as_integer(self):
- enum_ir = _make_ir_from_emb("enum Foo:\n"
- " [may_be_used_as_integer: false]\n"
- " VALUE = 1\n")
- enum_type_ir = enum_ir.module[0].type[0]
- self.assertEqual([[
- error.error(
- "m.emb", enum_type_ir.attribute[0].name.source_location,
- "Unknown attribute 'may_be_used_as_integer' on enum 'Foo'.")
- ]], attribute_checker.normalize_and_verify(enum_ir))
+ def test_rejects_may_be_used_as_integer(self):
+ enum_ir = _make_ir_from_emb(
+ "enum Foo:\n" " [may_be_used_as_integer: false]\n" " VALUE = 1\n"
+ )
+ enum_type_ir = enum_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ enum_type_ir.attribute[0].name.source_location,
+ "Unknown attribute 'may_be_used_as_integer' on enum 'Foo'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(enum_ir),
+ )
- def test_adds_fixed_size_attribute_to_struct(self):
- # field2 is intentionally after field3, in order to trigger certain code
- # paths in attribute_checker.py.
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+2] UInt field1\n"
- " 4 [+4] UInt field2\n"
- " 2 [+2] UInt field3\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
- size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute,
- _FIXED_SIZE)
- self.assertEqual(64, ir_util.constant_value(size_attr.expression))
- self.assertEqual(struct_ir.module[0].type[0].source_location,
- size_attr.source_location)
+ def test_adds_fixed_size_attribute_to_struct(self):
+ # field2 is intentionally after field3, in order to trigger certain code
+ # paths in attribute_checker.py.
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+2] UInt field1\n"
+ " 4 [+4] UInt field2\n"
+ " 2 [+2] UInt field3\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
+ size_attr = ir_util.get_attribute(
+ struct_ir.module[0].type[0].attribute, _FIXED_SIZE
+ )
+ self.assertEqual(64, ir_util.constant_value(size_attr.expression))
+ self.assertEqual(
+ struct_ir.module[0].type[0].source_location, size_attr.source_location
+ )
- def test_adds_fixed_size_attribute_to_struct_with_virtual_field(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+2] UInt field1\n"
- " let field2 = field1\n"
- " 2 [+2] UInt field3\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
- size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute,
- _FIXED_SIZE)
- self.assertEqual(32, ir_util.constant_value(size_attr.expression))
- self.assertEqual(struct_ir.module[0].type[0].source_location,
- size_attr.source_location)
+ def test_adds_fixed_size_attribute_to_struct_with_virtual_field(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+2] UInt field1\n"
+ " let field2 = field1\n"
+ " 2 [+2] UInt field3\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
+ size_attr = ir_util.get_attribute(
+ struct_ir.module[0].type[0].attribute, _FIXED_SIZE
+ )
+ self.assertEqual(32, ir_util.constant_value(size_attr.expression))
+ self.assertEqual(
+ struct_ir.module[0].type[0].source_location, size_attr.source_location
+ )
- def test_adds_fixed_size_attribute_to_anonymous_bits(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+4] bits:\n"
- " 0 [+8] UInt field\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
- size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute,
- _FIXED_SIZE)
- self.assertEqual(32, ir_util.constant_value(size_attr.expression))
- bits_size_attr = ir_util.get_attribute(
- struct_ir.module[0].type[0].subtype[0].attribute, _FIXED_SIZE)
- self.assertEqual(8, ir_util.constant_value(bits_size_attr.expression))
- self.assertEqual(struct_ir.module[0].type[0].source_location,
- size_attr.source_location)
+ def test_adds_fixed_size_attribute_to_anonymous_bits(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+4] bits:\n"
+ " 0 [+8] UInt field\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
+ size_attr = ir_util.get_attribute(
+ struct_ir.module[0].type[0].attribute, _FIXED_SIZE
+ )
+ self.assertEqual(32, ir_util.constant_value(size_attr.expression))
+ bits_size_attr = ir_util.get_attribute(
+ struct_ir.module[0].type[0].subtype[0].attribute, _FIXED_SIZE
+ )
+ self.assertEqual(8, ir_util.constant_value(bits_size_attr.expression))
+ self.assertEqual(
+ struct_ir.module[0].type[0].source_location, size_attr.source_location
+ )
- def test_does_not_add_fixed_size_attribute_to_variable_size_struct(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+4] UInt n\n"
- " 4 [+n] UInt:8[] payload\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
- self.assertIsNone(ir_util.get_attribute(
- struct_ir.module[0].type[0].attribute, _FIXED_SIZE))
+ def test_does_not_add_fixed_size_attribute_to_variable_size_struct(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+4] UInt n\n"
+ " 4 [+n] UInt:8[] payload\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
+ self.assertIsNone(
+ ir_util.get_attribute(struct_ir.module[0].type[0].attribute, _FIXED_SIZE)
+ )
- def test_accepts_correct_fixed_size_and_size_attributes_on_struct(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " [fixed_size_in_bits: 64]\n"
- " 0 [+2] UInt field1\n"
- " 2 [+2] UInt field2\n"
- " 4 [+4] UInt field3\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
- size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute,
- _FIXED_SIZE)
- self.assertTrue(size_attr)
- self.assertEqual(64, ir_util.constant_value(size_attr.expression))
+ def test_accepts_correct_fixed_size_and_size_attributes_on_struct(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " [fixed_size_in_bits: 64]\n"
+ " 0 [+2] UInt field1\n"
+ " 2 [+2] UInt field2\n"
+ " 4 [+4] UInt field3\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
+ size_attr = ir_util.get_attribute(
+ struct_ir.module[0].type[0].attribute, _FIXED_SIZE
+ )
+ self.assertTrue(size_attr)
+ self.assertEqual(64, ir_util.constant_value(size_attr.expression))
- def test_accepts_correct_size_attribute_on_struct(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " [fixed_size_in_bits: 64]\n"
- " 0 [+2] UInt field1\n"
- " 4 [+4] UInt field3\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
- size_attr = ir_util.get_attribute(struct_ir.module[0].type[0].attribute,
- _FIXED_SIZE)
- self.assertTrue(size_attr.expression)
- self.assertEqual(64, ir_util.constant_value(size_attr.expression))
+ def test_accepts_correct_size_attribute_on_struct(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " [fixed_size_in_bits: 64]\n"
+ " 0 [+2] UInt field1\n"
+ " 4 [+4] UInt field3\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(struct_ir))
+ size_attr = ir_util.get_attribute(
+ struct_ir.module[0].type[0].attribute, _FIXED_SIZE
+ )
+ self.assertTrue(size_attr.expression)
+ self.assertEqual(64, ir_util.constant_value(size_attr.expression))
- def test_rejects_incorrect_fixed_size_attribute_on_variable_size_struct(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " [fixed_size_in_bits: 8]\n"
- " 0 [+4] UInt n\n"
- " 4 [+n] UInt:8[] payload\n")
- struct_type_ir = struct_ir.module[0].type[0]
- self.assertEqual([[error.error(
- "m.emb", struct_type_ir.attribute[0].value.source_location,
- "Struct is marked as fixed size, but contains variable-location "
- "fields.")]], attribute_checker.normalize_and_verify(struct_ir))
+ def test_rejects_incorrect_fixed_size_attribute_on_variable_size_struct(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " [fixed_size_in_bits: 8]\n"
+ " 0 [+4] UInt n\n"
+ " 4 [+n] UInt:8[] payload\n"
+ )
+ struct_type_ir = struct_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct_type_ir.attribute[0].value.source_location,
+ "Struct is marked as fixed size, but contains variable-location "
+ "fields.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(struct_ir),
+ )
- def test_rejects_size_attribute_with_wrong_large_value_on_struct(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " [fixed_size_in_bits: 80]\n"
- " 0 [+2] UInt field1\n"
- " 2 [+2] UInt field2\n"
- " 4 [+4] UInt field3\n")
- struct_type_ir = struct_ir.module[0].type[0]
- self.assertEqual([
- [error.error("m.emb", struct_type_ir.attribute[0].value.source_location,
- "Struct is 64 bits, but is marked as 80 bits.")]
- ], attribute_checker.normalize_and_verify(struct_ir))
+ def test_rejects_size_attribute_with_wrong_large_value_on_struct(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " [fixed_size_in_bits: 80]\n"
+ " 0 [+2] UInt field1\n"
+ " 2 [+2] UInt field2\n"
+ " 4 [+4] UInt field3\n"
+ )
+ struct_type_ir = struct_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct_type_ir.attribute[0].value.source_location,
+ "Struct is 64 bits, but is marked as 80 bits.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(struct_ir),
+ )
- def test_rejects_size_attribute_with_wrong_small_value_on_struct(self):
- struct_ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " [fixed_size_in_bits: 40]\n"
- " 0 [+2] UInt field1\n"
- " 2 [+2] UInt field2\n"
- " 4 [+4] UInt field3\n")
- struct_type_ir = struct_ir.module[0].type[0]
- self.assertEqual([
- [error.error("m.emb", struct_type_ir.attribute[0].value.source_location,
- "Struct is 64 bits, but is marked as 40 bits.")]
- ], attribute_checker.normalize_and_verify(struct_ir))
+ def test_rejects_size_attribute_with_wrong_small_value_on_struct(self):
+ struct_ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " [fixed_size_in_bits: 40]\n"
+ " 0 [+2] UInt field1\n"
+ " 2 [+2] UInt field2\n"
+ " 4 [+4] UInt field3\n"
+ )
+ struct_type_ir = struct_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct_type_ir.attribute[0].value.source_location,
+ "Struct is 64 bits, but is marked as 40 bits.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(struct_ir),
+ )
- def test_accepts_variable_size_external(self):
- external_ir = _make_ir_from_emb("external Foo:\n"
- " [addressable_unit_size: 1]\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
+ def test_accepts_variable_size_external(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n" " [addressable_unit_size: 1]\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
- def test_accepts_fixed_size_external(self):
- external_ir = _make_ir_from_emb("external Foo:\n"
- " [fixed_size_in_bits: 32]\n"
- " [addressable_unit_size: 1]\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
+ def test_accepts_fixed_size_external(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n"
+ " [fixed_size_in_bits: 32]\n"
+ " [addressable_unit_size: 1]\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
- def test_rejects_external_with_no_addressable_unit_size_attribute(self):
- external_ir = _make_ir_from_emb("external Foo:\n"
- " [is_integer: false]\n")
- external_type_ir = external_ir.module[0].type[0]
- self.assertEqual([
- [error.error(
- "m.emb", external_type_ir.source_location,
- "Expected 'addressable_unit_size' attribute for external type.")]
- ], attribute_checker.normalize_and_verify(external_ir))
+ def test_rejects_external_with_no_addressable_unit_size_attribute(self):
+ external_ir = _make_ir_from_emb("external Foo:\n" " [is_integer: false]\n")
+ external_type_ir = external_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ external_type_ir.source_location,
+ "Expected 'addressable_unit_size' attribute for external type.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(external_ir),
+ )
- def test_rejects_is_integer_with_non_constant_value(self):
- external_ir = _make_ir_from_emb(
- "external Foo:\n"
- " [is_integer: $static_size_in_bits == 1]\n"
- " [addressable_unit_size: 1]\n")
- external_type_ir = external_ir.module[0].type[0]
- self.assertEqual([
- [error.error(
- "m.emb", external_type_ir.attribute[0].value.source_location,
- "Attribute 'is_integer' must have a constant boolean value.")]
- ], attribute_checker.normalize_and_verify(external_ir))
+ def test_rejects_is_integer_with_non_constant_value(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n"
+ " [is_integer: $static_size_in_bits == 1]\n"
+ " [addressable_unit_size: 1]\n"
+ )
+ external_type_ir = external_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ external_type_ir.attribute[0].value.source_location,
+ "Attribute 'is_integer' must have a constant boolean value.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(external_ir),
+ )
- def test_rejects_addressable_unit_size_with_non_constant_value(self):
- external_ir = _make_ir_from_emb(
- "external Foo:\n"
- " [is_integer: true]\n"
- " [addressable_unit_size: $static_size_in_bits]\n")
- external_type_ir = external_ir.module[0].type[0]
- self.assertEqual([
- [error.error(
- "m.emb", external_type_ir.attribute[1].value.source_location,
- "Attribute 'addressable_unit_size' must have a constant value.")]
- ], attribute_checker.normalize_and_verify(external_ir))
+ def test_rejects_addressable_unit_size_with_non_constant_value(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n"
+ " [is_integer: true]\n"
+ " [addressable_unit_size: $static_size_in_bits]\n"
+ )
+ external_type_ir = external_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ external_type_ir.attribute[1].value.source_location,
+ "Attribute 'addressable_unit_size' must have a constant value.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(external_ir),
+ )
- def test_rejects_external_with_wrong_addressable_unit_size_attribute(self):
- external_ir = _make_ir_from_emb("external Foo:\n"
- " [addressable_unit_size: 4]\n")
- external_type_ir = external_ir.module[0].type[0]
- self.assertEqual([
- [error.error(
- "m.emb", external_type_ir.source_location,
- "Only values '1' (bit) and '8' (byte) are allowed for the "
- "'addressable_unit_size' attribute")]
- ], attribute_checker.normalize_and_verify(external_ir))
+ def test_rejects_external_with_wrong_addressable_unit_size_attribute(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n" " [addressable_unit_size: 4]\n"
+ )
+ external_type_ir = external_ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ external_type_ir.source_location,
+ "Only values '1' (bit) and '8' (byte) are allowed for the "
+ "'addressable_unit_size' attribute",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(external_ir),
+ )
- def test_rejects_duplicate_attribute(self):
- ir = _make_ir_from_emb("external Foo:\n"
- " [is_integer: true]\n"
- " [is_integer: true]\n")
- self.assertEqual([[
- error.error("m.emb", ir.module[0].type[0].attribute[1].source_location,
- "Duplicate attribute 'is_integer'."),
- error.note("m.emb", ir.module[0].type[0].attribute[0].source_location,
- "Original attribute"),
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_duplicate_attribute(self):
+ ir = _make_ir_from_emb(
+ "external Foo:\n" " [is_integer: true]\n" " [is_integer: true]\n"
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ ir.module[0].type[0].attribute[1].source_location,
+ "Duplicate attribute 'is_integer'.",
+ ),
+ error.note(
+ "m.emb",
+ ir.module[0].type[0].attribute[0].source_location,
+ "Original attribute",
+ ),
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_duplicate_default_attribute(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- '[$default byte_order: "LittleEndian"]\n')
- self.assertEqual(
- [[
- error.error("m.emb", ir.module[0].attribute[1].source_location,
- "Duplicate attribute 'byte_order'."),
- error.note("m.emb", ir.module[0].attribute[0].source_location,
- "Original attribute"),
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_duplicate_default_attribute(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ '[$default byte_order: "LittleEndian"]\n'
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ ir.module[0].attribute[1].source_location,
+ "Duplicate attribute 'byte_order'.",
+ ),
+ error.note(
+ "m.emb",
+ ir.module[0].attribute[0].source_location,
+ "Original attribute",
+ ),
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_unknown_attribute(self):
- ir = _make_ir_from_emb("[gibberish: true]\n")
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.name.source_location,
- "Unknown attribute 'gibberish' on module 'm.emb'.")
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_unknown_attribute(self):
+ ir = _make_ir_from_emb("[gibberish: true]\n")
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.name.source_location,
+ "Unknown attribute 'gibberish' on module 'm.emb'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_non_constant_attribute(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " [fixed_size_in_bits: field1]\n"
- " 0 [+2] UInt field1\n")
- attr = ir.module[0].type[0].attribute[0]
- self.assertEqual(
- [[
- error.error(
- "m.emb", attr.value.source_location,
- "Attribute 'fixed_size_in_bits' must have a constant value.")
- ]],
- attribute_checker.normalize_and_verify(ir))
+ def test_rejects_non_constant_attribute(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " [fixed_size_in_bits: field1]\n"
+ " 0 [+2] UInt field1\n"
+ )
+ attr = ir.module[0].type[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Attribute 'fixed_size_in_bits' must have a constant value.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_attribute_missing_required_back_end_specifier(self):
- ir = _make_ir_from_emb('[namespace: "abc"]\n')
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error("m.emb", attr.name.source_location,
- "Unknown attribute 'namespace' on module 'm.emb'.")
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_attribute_missing_required_back_end_specifier(self):
+ ir = _make_ir_from_emb('[namespace: "abc"]\n')
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.name.source_location,
+ "Unknown attribute 'namespace' on module 'm.emb'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_accepts_attribute_with_default_known_back_end_specifier(self):
- ir = _make_ir_from_emb('[(cpp) namespace: "abc"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ def test_accepts_attribute_with_default_known_back_end_specifier(self):
+ ir = _make_ir_from_emb('[(cpp) namespace: "abc"]\n')
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- def test_rejects_attribute_with_specified_back_end_specifier(self):
- ir = _make_ir_from_emb('[(c) namespace: "abc"]\n'
- '[expected_back_ends: "c, cpp"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_attribute_with_specified_back_end_specifier(self):
+ ir = _make_ir_from_emb(
+ '[(c) namespace: "abc"]\n' '[expected_back_ends: "c, cpp"]\n'
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- def test_rejects_cpp_backend_attribute_when_not_in_expected_back_ends(self):
- ir = _make_ir_from_emb('[(cpp) namespace: "abc"]\n'
- '[expected_back_ends: "c"]\n')
- attr = ir.module[0].attribute[0]
- self.maxDiff = 200000
- self.assertEqual([[
- error.error(
- "m.emb", attr.back_end.source_location,
- "Back end specifier 'cpp' does not match any expected back end "
- "specifier for this file: 'c'. Add or update the "
- "'[expected_back_ends: \"c, cpp\"]' attribute at the file level if "
- "this back end specifier is intentional.")
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_cpp_backend_attribute_when_not_in_expected_back_ends(self):
+ ir = _make_ir_from_emb(
+ '[(cpp) namespace: "abc"]\n' '[expected_back_ends: "c"]\n'
+ )
+ attr = ir.module[0].attribute[0]
+ self.maxDiff = 200000
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.back_end.source_location,
+ "Back end specifier 'cpp' does not match any expected back end "
+ "specifier for this file: 'c'. Add or update the "
+ "'[expected_back_ends: \"c, cpp\"]' attribute at the file level if "
+ "this back end specifier is intentional.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_expected_back_ends_with_bad_back_end(self):
- ir = _make_ir_from_emb('[expected_back_ends: "c++"]\n')
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error(
- "m.emb", attr.value.source_location,
- "Attribute 'expected_back_ends' must be a comma-delimited list of "
- "back end specifiers (like \"cpp, proto\")), not \"c++\".")
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_expected_back_ends_with_bad_back_end(self):
+ ir = _make_ir_from_emb('[expected_back_ends: "c++"]\n')
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Attribute 'expected_back_ends' must be a comma-delimited list of "
+ 'back end specifiers (like "cpp, proto")), not "c++".',
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_expected_back_ends_with_no_comma(self):
- ir = _make_ir_from_emb('[expected_back_ends: "cpp z"]\n')
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error(
- "m.emb", attr.value.source_location,
- "Attribute 'expected_back_ends' must be a comma-delimited list of "
- "back end specifiers (like \"cpp, proto\")), not \"cpp z\".")
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_expected_back_ends_with_no_comma(self):
+ ir = _make_ir_from_emb('[expected_back_ends: "cpp z"]\n')
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Attribute 'expected_back_ends' must be a comma-delimited list of "
+ 'back end specifiers (like "cpp, proto")), not "cpp z".',
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_expected_back_ends_with_extra_commas(self):
- ir = _make_ir_from_emb('[expected_back_ends: "cpp,,z"]\n')
- attr = ir.module[0].attribute[0]
- self.assertEqual([[
- error.error(
- "m.emb", attr.value.source_location,
- "Attribute 'expected_back_ends' must be a comma-delimited list of "
- "back end specifiers (like \"cpp, proto\")), not \"cpp,,z\".")
- ]], attribute_checker.normalize_and_verify(ir))
+ def test_rejects_expected_back_ends_with_extra_commas(self):
+ ir = _make_ir_from_emb('[expected_back_ends: "cpp,,z"]\n')
+ attr = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attr.value.source_location,
+ "Attribute 'expected_back_ends' must be a comma-delimited list of "
+ 'back end specifiers (like "cpp, proto")), not "cpp,,z".',
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_accepts_empty_expected_back_ends(self):
- ir = _make_ir_from_emb('[expected_back_ends: ""]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ def test_accepts_empty_expected_back_ends(self):
+ ir = _make_ir_from_emb('[expected_back_ends: ""]\n')
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- def test_adds_byte_order_attributes_from_default(self):
- ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n'
- "struct Foo:\n"
- " 0 [+2] UInt bar\n"
- " 2 [+2] UInt baz\n"
- ' [byte_order: "LittleEndian"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- byte_order_attr = ir_util.get_attribute(
- ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("BigEndian", byte_order_attr.string_constant.text)
- byte_order_attr = ir_util.get_attribute(
- ir.module[0].type[0].structure.field[1].attribute, _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("LittleEndian", byte_order_attr.string_constant.text)
+ def test_adds_byte_order_attributes_from_default(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "BigEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+2] UInt bar\n"
+ " 2 [+2] UInt baz\n"
+ ' [byte_order: "LittleEndian"]\n'
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ byte_order_attr = ir_util.get_attribute(
+ ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("BigEndian", byte_order_attr.string_constant.text)
+ byte_order_attr = ir_util.get_attribute(
+ ir.module[0].type[0].structure.field[1].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("LittleEndian", byte_order_attr.string_constant.text)
- def test_adds_null_byte_order_attributes(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt bar\n"
- " 1 [+1] UInt baz\n"
- ' [byte_order: "LittleEndian"]\n'
- " 2 [+2] UInt:8[] baseball\n"
- " 4 [+2] UInt:8[] bat\n"
- ' [byte_order: "LittleEndian"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- structure = ir.module[0].type[0].structure
- byte_order_attr = ir_util.get_attribute(
- structure.field[0].attribute, _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("Null", byte_order_attr.string_constant.text)
- self.assertEqual(structure.field[0].source_location,
- byte_order_attr.source_location)
- byte_order_attr = ir_util.get_attribute(structure.field[1].attribute,
- _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("LittleEndian", byte_order_attr.string_constant.text)
- byte_order_attr = ir_util.get_attribute(structure.field[2].attribute,
- _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("Null", byte_order_attr.string_constant.text)
- self.assertEqual(structure.field[2].source_location,
- byte_order_attr.source_location)
- byte_order_attr = ir_util.get_attribute(structure.field[3].attribute,
- _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("LittleEndian", byte_order_attr.string_constant.text)
+ def test_adds_null_byte_order_attributes(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] UInt bar\n"
+ " 1 [+1] UInt baz\n"
+ ' [byte_order: "LittleEndian"]\n'
+ " 2 [+2] UInt:8[] baseball\n"
+ " 4 [+2] UInt:8[] bat\n"
+ ' [byte_order: "LittleEndian"]\n'
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ structure = ir.module[0].type[0].structure
+ byte_order_attr = ir_util.get_attribute(
+ structure.field[0].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("Null", byte_order_attr.string_constant.text)
+ self.assertEqual(
+ structure.field[0].source_location, byte_order_attr.source_location
+ )
+ byte_order_attr = ir_util.get_attribute(
+ structure.field[1].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("LittleEndian", byte_order_attr.string_constant.text)
+ byte_order_attr = ir_util.get_attribute(
+ structure.field[2].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("Null", byte_order_attr.string_constant.text)
+ self.assertEqual(
+ structure.field[2].source_location, byte_order_attr.source_location
+ )
+ byte_order_attr = ir_util.get_attribute(
+ structure.field[3].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("LittleEndian", byte_order_attr.string_constant.text)
- def test_disallows_default_byte_order_on_field(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+2] UInt bar\n"
- ' [$default byte_order: "LittleEndian"]\n')
- default_byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", default_byte_order.name.source_location,
- "Attribute 'byte_order' may not be defaulted on struct field 'bar'."
- )]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_default_byte_order_on_field(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+2] UInt bar\n"
+ ' [$default byte_order: "LittleEndian"]\n'
+ )
+ default_byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ default_byte_order.name.source_location,
+ "Attribute 'byte_order' may not be defaulted on struct field 'bar'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_default_byte_order_on_bits(self):
- ir = _make_ir_from_emb("bits Foo:\n"
- ' [$default byte_order: "LittleEndian"]\n'
- " 0 [+2] UInt bar\n")
- default_byte_order = ir.module[0].type[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", default_byte_order.name.source_location,
- "Attribute 'byte_order' may not be defaulted on bits 'Foo'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_default_byte_order_on_bits(self):
+ ir = _make_ir_from_emb(
+ "bits Foo:\n"
+ ' [$default byte_order: "LittleEndian"]\n'
+ " 0 [+2] UInt bar\n"
+ )
+ default_byte_order = ir.module[0].type[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ default_byte_order.name.source_location,
+ "Attribute 'byte_order' may not be defaulted on bits 'Foo'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_default_byte_order_on_enum(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- ' [$default byte_order: "LittleEndian"]\n'
- " BAR = 1\n")
- default_byte_order = ir.module[0].type[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", default_byte_order.name.source_location,
- "Attribute 'byte_order' may not be defaulted on enum 'Foo'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_default_byte_order_on_enum(self):
+ ir = _make_ir_from_emb(
+ "enum Foo:\n" ' [$default byte_order: "LittleEndian"]\n' " BAR = 1\n"
+ )
+ default_byte_order = ir.module[0].type[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ default_byte_order.name.source_location,
+ "Attribute 'byte_order' may not be defaulted on enum 'Foo'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_adds_byte_order_from_scoped_default(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- ' [$default byte_order: "BigEndian"]\n'
- " 0 [+2] UInt bar\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- byte_order_attr = ir_util.get_attribute(
- ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER)
- self.assertTrue(byte_order_attr.HasField("string_constant"))
- self.assertEqual("BigEndian", byte_order_attr.string_constant.text)
+ def test_adds_byte_order_from_scoped_default(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ ' [$default byte_order: "BigEndian"]\n'
+ " 0 [+2] UInt bar\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ byte_order_attr = ir_util.get_attribute(
+ ir.module[0].type[0].structure.field[0].attribute, _BYTE_ORDER
+ )
+ self.assertTrue(byte_order_attr.HasField("string_constant"))
+ self.assertEqual("BigEndian", byte_order_attr.string_constant.text)
- def test_disallows_unknown_byte_order(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] UInt bar\n"
- ' [byte_order: "NoEndian"]\n')
- byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.value.source_location,
- "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or "
- "'Null'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_unknown_byte_order(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+2] UInt bar\n" ' [byte_order: "NoEndian"]\n'
+ )
+ byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.value.source_location,
+ "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or "
+ "'Null'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_unknown_default_byte_order(self):
- ir = _make_ir_from_emb('[$default byte_order: "NoEndian"]\n')
- default_byte_order = ir.module[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", default_byte_order.value.source_location,
- "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or "
- "'Null'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_unknown_default_byte_order(self):
+ ir = _make_ir_from_emb('[$default byte_order: "NoEndian"]\n')
+ default_byte_order = ir.module[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ default_byte_order.value.source_location,
+ "Attribute 'byte_order' must be 'BigEndian' or 'LittleEndian' or "
+ "'Null'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_byte_order_on_non_byte_order_dependent_fields(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- ' [$default byte_order: "LittleEndian"]\n'
- " 0 [+2] UInt uint\n"
- "struct Bar:\n"
- " 0 [+2] Foo foo\n"
- ' [byte_order: "LittleEndian"]\n')
- byte_order = ir.module[0].type[1].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.value.source_location,
- "Attribute 'byte_order' not allowed on field which is not byte "
- "order dependent.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_byte_order_on_non_byte_order_dependent_fields(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ ' [$default byte_order: "LittleEndian"]\n'
+ " 0 [+2] UInt uint\n"
+ "struct Bar:\n"
+ " 0 [+2] Foo foo\n"
+ ' [byte_order: "LittleEndian"]\n'
+ )
+ byte_order = ir.module[0].type[1].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.value.source_location,
+ "Attribute 'byte_order' not allowed on field which is not byte "
+ "order dependent.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_byte_order_on_virtual_field(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " let x = 10\n"
- ' [byte_order: "LittleEndian"]\n')
- byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.name.source_location,
- "Unknown attribute 'byte_order' on virtual struct field 'x'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_byte_order_on_virtual_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " let x = 10\n" ' [byte_order: "LittleEndian"]\n'
+ )
+ byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.name.source_location,
+ "Unknown attribute 'byte_order' on virtual struct field 'x'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_null_byte_order_on_multibyte_fields(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] UInt uint\n"
- ' [byte_order: "Null"]\n')
- byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.value.source_location,
- "Attribute 'byte_order' may only be 'Null' for one-byte fields.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_null_byte_order_on_multibyte_fields(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+2] UInt uint\n" ' [byte_order: "Null"]\n'
+ )
+ byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.value.source_location,
+ "Attribute 'byte_order' may only be 'Null' for one-byte fields.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_null_byte_order_on_multibyte_array_elements(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+4] UInt:16[] uint\n"
- ' [byte_order: "Null"]\n')
- byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.value.source_location,
- "Attribute 'byte_order' may only be 'Null' for one-byte fields.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_null_byte_order_on_multibyte_array_elements(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+4] UInt:16[] uint\n" ' [byte_order: "Null"]\n'
+ )
+ byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.value.source_location,
+ "Attribute 'byte_order' may only be 'Null' for one-byte fields.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_requires_byte_order_on_byte_order_dependent_fields(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] UInt uint\n")
- field = ir.module[0].type[0].structure.field[0]
- self.assertEqual(
- [[error.error(
- "m.emb", field.source_location,
- "Attribute 'byte_order' required on field which is byte order "
- "dependent.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_requires_byte_order_on_byte_order_dependent_fields(self):
+ ir = _make_ir_from_emb("struct Foo:\n" " 0 [+2] UInt uint\n")
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ field.source_location,
+ "Attribute 'byte_order' required on field which is byte order "
+ "dependent.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_unknown_text_output_attribute(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] UInt bar\n"
- ' [text_output: "None"]\n')
- byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.value.source_location,
- "Attribute 'text_output' must be 'Emit' or 'Skip'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_unknown_text_output_attribute(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+2] UInt bar\n" ' [text_output: "None"]\n'
+ )
+ byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.value.source_location,
+ "Attribute 'text_output' must be 'Emit' or 'Skip'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_disallows_non_string_text_output_attribute(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] UInt bar\n"
- " [text_output: 0]\n")
- byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", byte_order.value.source_location,
- "Attribute 'text_output' must be 'Emit' or 'Skip'.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_disallows_non_string_text_output_attribute(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+2] UInt bar\n" " [text_output: 0]\n"
+ )
+ byte_order = ir.module[0].type[0].structure.field[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ byte_order.value.source_location,
+ "Attribute 'text_output' must be 'Emit' or 'Skip'.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_allows_skip_text_output_attribute_on_physical_field(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt bar\n"
- ' [text_output: "Skip"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ def test_allows_skip_text_output_attribute_on_physical_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+1] UInt bar\n" ' [text_output: "Skip"]\n'
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- def test_allows_skip_text_output_attribute_on_virtual_field(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " let x = 10\n"
- ' [text_output: "Skip"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ def test_allows_skip_text_output_attribute_on_virtual_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " let x = 10\n" ' [text_output: "Skip"]\n'
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- def test_allows_emit_text_output_attribute_on_physical_field(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt bar\n"
- ' [text_output: "Emit"]\n')
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ def test_allows_emit_text_output_attribute_on_physical_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+1] UInt bar\n" ' [text_output: "Emit"]\n'
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- def test_adds_bit_addressable_unit_to_external(self):
- external_ir = _make_ir_from_emb("external Foo:\n"
- " [addressable_unit_size: 1]\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
- self.assertEqual(ir_data.AddressableUnit.BIT,
- external_ir.module[0].type[0].addressable_unit)
+ def test_adds_bit_addressable_unit_to_external(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n" " [addressable_unit_size: 1]\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
+ self.assertEqual(
+ ir_data.AddressableUnit.BIT, external_ir.module[0].type[0].addressable_unit
+ )
- def test_adds_byte_addressable_unit_to_external(self):
- external_ir = _make_ir_from_emb("external Foo:\n"
- " [addressable_unit_size: 8]\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
- self.assertEqual(ir_data.AddressableUnit.BYTE,
- external_ir.module[0].type[0].addressable_unit)
+ def test_adds_byte_addressable_unit_to_external(self):
+ external_ir = _make_ir_from_emb(
+ "external Foo:\n" " [addressable_unit_size: 8]\n"
+ )
+ self.assertEqual([], attribute_checker.normalize_and_verify(external_ir))
+ self.assertEqual(
+ ir_data.AddressableUnit.BYTE, external_ir.module[0].type[0].addressable_unit
+ )
- def test_rejects_requires_using_array(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+4] UInt:8[] array\n"
- " [requires: this]\n")
- field_ir = ir.module[0].type[0].structure.field[0]
- self.assertEqual(
- [[error.error("m.emb", field_ir.attribute[0].value.source_location,
- "Attribute 'requires' must have a boolean value.")]],
- attribute_checker.normalize_and_verify(ir))
+ def test_rejects_requires_using_array(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+4] UInt:8[] array\n" " [requires: this]\n"
+ )
+ field_ir = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ field_ir.attribute[0].value.source_location,
+ "Attribute 'requires' must have a boolean value.",
+ )
+ ]
+ ],
+ attribute_checker.normalize_and_verify(ir),
+ )
- def test_rejects_requires_on_array(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+4] UInt:8[] array\n"
- " [requires: false]\n")
- field_ir = ir.module[0].type[0].structure.field[0]
- self.assertEqual(
- [[
- error.error("m.emb", field_ir.attribute[0].value.source_location,
+ def test_rejects_requires_on_array(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+4] UInt:8[] array\n" " [requires: false]\n"
+ )
+ field_ir = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ field_ir.attribute[0].value.source_location,
"Attribute 'requires' is only allowed on integer, "
- "enumeration, or boolean fields, not arrays."),
- error.note("m.emb", field_ir.type.source_location,
- "Field type."),
- ]],
- error.filter_errors(attribute_checker.normalize_and_verify(ir)))
+ "enumeration, or boolean fields, not arrays.",
+ ),
+ error.note("m.emb", field_ir.type.source_location, "Field type."),
+ ]
+ ],
+ error.filter_errors(attribute_checker.normalize_and_verify(ir)),
+ )
- def test_rejects_requires_on_struct(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+4] Bar bar\n"
- " [requires: false]\n"
- "struct Bar:\n"
- " 0 [+4] UInt uint\n")
- field_ir = ir.module[0].type[0].structure.field[0]
- self.assertEqual(
- [[error.error("m.emb", field_ir.attribute[0].value.source_location,
- "Attribute 'requires' is only allowed on integer, "
- "enumeration, or boolean fields.")]],
- error.filter_errors(attribute_checker.normalize_and_verify(ir)))
+ def test_rejects_requires_on_struct(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+4] Bar bar\n"
+ " [requires: false]\n"
+ "struct Bar:\n"
+ " 0 [+4] UInt uint\n"
+ )
+ field_ir = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ field_ir.attribute[0].value.source_location,
+ "Attribute 'requires' is only allowed on integer, "
+ "enumeration, or boolean fields.",
+ )
+ ]
+ ],
+ error.filter_errors(attribute_checker.normalize_and_verify(ir)),
+ )
- def test_rejects_requires_on_float(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+4] Float float\n"
- " [requires: false]\n")
- field_ir = ir.module[0].type[0].structure.field[0]
- self.assertEqual(
- [[error.error("m.emb", field_ir.attribute[0].value.source_location,
- "Attribute 'requires' is only allowed on integer, "
- "enumeration, or boolean fields.")]],
- error.filter_errors(attribute_checker.normalize_and_verify(ir)))
+ def test_rejects_requires_on_float(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+4] Float float\n"
+ " [requires: false]\n"
+ )
+ field_ir = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ field_ir.attribute[0].value.source_location,
+ "Attribute 'requires' is only allowed on integer, "
+ "enumeration, or boolean fields.",
+ )
+ ]
+ ],
+ error.filter_errors(attribute_checker.normalize_and_verify(ir)),
+ )
- def test_adds_false_is_signed_attribute(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " ZERO = 0\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- enum = ir.module[0].type[0]
- is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED)
- self.assertTrue(is_signed_attr.expression.HasField("boolean_constant"))
- self.assertFalse(is_signed_attr.expression.boolean_constant.value)
+ def test_adds_false_is_signed_attribute(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " ZERO = 0\n")
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ enum = ir.module[0].type[0]
+ is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED)
+ self.assertTrue(is_signed_attr.expression.HasField("boolean_constant"))
+ self.assertFalse(is_signed_attr.expression.boolean_constant.value)
- def test_leaves_is_signed_attribute(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " [is_signed: true]\n"
- " ZERO = 0\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- enum = ir.module[0].type[0]
- is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED)
- self.assertTrue(is_signed_attr.expression.HasField("boolean_constant"))
- self.assertTrue(is_signed_attr.expression.boolean_constant.value)
+ def test_leaves_is_signed_attribute(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " [is_signed: true]\n" " ZERO = 0\n")
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ enum = ir.module[0].type[0]
+ is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED)
+ self.assertTrue(is_signed_attr.expression.HasField("boolean_constant"))
+ self.assertTrue(is_signed_attr.expression.boolean_constant.value)
- def test_adds_true_is_signed_attribute(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " NEGATIVE_ONE = -1\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- enum = ir.module[0].type[0]
- is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED)
- self.assertTrue(is_signed_attr.expression.HasField("boolean_constant"))
- self.assertTrue(is_signed_attr.expression.boolean_constant.value)
+ def test_adds_true_is_signed_attribute(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " NEGATIVE_ONE = -1\n")
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ enum = ir.module[0].type[0]
+ is_signed_attr = ir_util.get_attribute(enum.attribute, _IS_SIGNED)
+ self.assertTrue(is_signed_attr.expression.HasField("boolean_constant"))
+ self.assertTrue(is_signed_attr.expression.boolean_constant.value)
- def test_adds_max_bits_attribute(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " ZERO = 0\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- enum = ir.module[0].type[0]
- max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS)
- self.assertTrue(max_bits_attr.expression.HasField("constant"))
- self.assertEqual("64", max_bits_attr.expression.constant.value)
+ def test_adds_max_bits_attribute(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " ZERO = 0\n")
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ enum = ir.module[0].type[0]
+ max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS)
+ self.assertTrue(max_bits_attr.expression.HasField("constant"))
+ self.assertEqual("64", max_bits_attr.expression.constant.value)
- def test_leaves_max_bits_attribute(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " [maximum_bits: 32]\n"
- " ZERO = 0\n")
- self.assertEqual([], attribute_checker.normalize_and_verify(ir))
- enum = ir.module[0].type[0]
- max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS)
- self.assertTrue(max_bits_attr.expression.HasField("constant"))
- self.assertEqual("32", max_bits_attr.expression.constant.value)
+ def test_leaves_max_bits_attribute(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " [maximum_bits: 32]\n" " ZERO = 0\n")
+ self.assertEqual([], attribute_checker.normalize_and_verify(ir))
+ enum = ir.module[0].type[0]
+ max_bits_attr = ir_util.get_attribute(enum.attribute, _MAX_BITS)
+ self.assertTrue(max_bits_attr.expression.HasField("constant"))
+ self.assertEqual("32", max_bits_attr.expression.constant.value)
- def test_rejects_too_small_max_bits(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " [maximum_bits: 0]\n"
- " ZERO = 0\n")
- attribute_ir = ir.module[0].type[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", attribute_ir.value.source_location,
- "'maximum_bits' on an 'enum' must be between 1 and 64.")]],
- error.filter_errors(attribute_checker.normalize_and_verify(ir)))
+ def test_rejects_too_small_max_bits(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " [maximum_bits: 0]\n" " ZERO = 0\n")
+ attribute_ir = ir.module[0].type[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attribute_ir.value.source_location,
+ "'maximum_bits' on an 'enum' must be between 1 and 64.",
+ )
+ ]
+ ],
+ error.filter_errors(attribute_checker.normalize_and_verify(ir)),
+ )
- def test_rejects_too_large_max_bits(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " [maximum_bits: 65]\n"
- " ZERO = 0\n")
- attribute_ir = ir.module[0].type[0].attribute[0]
- self.assertEqual(
- [[error.error(
- "m.emb", attribute_ir.value.source_location,
- "'maximum_bits' on an 'enum' must be between 1 and 64.")]],
- error.filter_errors(attribute_checker.normalize_and_verify(ir)))
+ def test_rejects_too_large_max_bits(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " [maximum_bits: 65]\n" " ZERO = 0\n")
+ attribute_ir = ir.module[0].type[0].attribute[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attribute_ir.value.source_location,
+ "'maximum_bits' on an 'enum' must be between 1 and 64.",
+ )
+ ]
+ ],
+ error.filter_errors(attribute_checker.normalize_and_verify(ir)),
+ )
- def test_rejects_unknown_enum_value_attribute(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " BAR = 0 \n"
- " [bad_attr: true]\n")
- attribute_ir = ir.module[0].type[0].enumeration.value[0].attribute[0]
- self.assertNotEqual([], attribute_checker.normalize_and_verify(ir))
- self.assertEqual(
- [[error.error(
- "m.emb", attribute_ir.name.source_location,
- "Unknown attribute 'bad_attr' on enum value 'BAR'.")]],
- error.filter_errors(attribute_checker.normalize_and_verify(ir)))
+ def test_rejects_unknown_enum_value_attribute(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " BAR = 0 \n" " [bad_attr: true]\n")
+ attribute_ir = ir.module[0].type[0].enumeration.value[0].attribute[0]
+ self.assertNotEqual([], attribute_checker.normalize_and_verify(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ attribute_ir.name.source_location,
+ "Unknown attribute 'bad_attr' on enum value 'BAR'.",
+ )
+ ]
+ ],
+ error.filter_errors(attribute_checker.normalize_and_verify(ir)),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py
index fabcf04..3249e6d 100644
--- a/compiler/front_end/constraints.py
+++ b/compiler/front_end/constraints.py
@@ -24,588 +24,764 @@
def _render_type(type_ir, ir):
- """Returns the human-readable notation of the given type."""
- assert type_ir.HasField("atomic_type"), (
- "TODO(bolms): Implement _render_type for array types.")
- if type_ir.HasField("size_in_bits"):
- return _render_atomic_type_name(
- type_ir,
- ir,
- suffix=":" + str(ir_util.constant_value(type_ir.size_in_bits)))
- else:
- return _render_atomic_type_name(type_ir, ir)
+ """Returns the human-readable notation of the given type."""
+ assert type_ir.HasField(
+ "atomic_type"
+ ), "TODO(bolms): Implement _render_type for array types."
+ if type_ir.HasField("size_in_bits"):
+ return _render_atomic_type_name(
+ type_ir, ir, suffix=":" + str(ir_util.constant_value(type_ir.size_in_bits))
+ )
+ else:
+ return _render_atomic_type_name(type_ir, ir)
def _render_atomic_type_name(type_ir, ir, suffix=None):
- assert type_ir.HasField("atomic_type"), (
- "_render_atomic_type_name() requires an atomic type")
- if not suffix:
- suffix = ""
- type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir)
- if type_definition.name.is_anonymous:
- return "anonymous type"
- else:
- return "type '{}{}'".format(type_definition.name.name.text, suffix)
+ assert type_ir.HasField(
+ "atomic_type"
+ ), "_render_atomic_type_name() requires an atomic type"
+ if not suffix:
+ suffix = ""
+ type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir)
+ if type_definition.name.is_anonymous:
+ return "anonymous type"
+ else:
+ return "type '{}{}'".format(type_definition.name.name.text, suffix)
-def _check_that_inner_array_dimensions_are_constant(
- type_ir, source_file_name, errors):
- """Checks that inner array dimensions are constant."""
- if type_ir.WhichOneof("size") == "automatic":
- errors.append([error.error(
- source_file_name,
- ir_data_utils.reader(type_ir).element_count.source_location,
- "Array dimensions can only be omitted for the outermost dimension.")])
- elif type_ir.WhichOneof("size") == "element_count":
- if not ir_util.is_constant(type_ir.element_count):
- errors.append([error.error(source_file_name,
- type_ir.element_count.source_location,
- "Inner array dimensions must be constant.")])
- else:
- assert False, 'Expected "element_count" or "automatic" array size.'
+def _check_that_inner_array_dimensions_are_constant(type_ir, source_file_name, errors):
+ """Checks that inner array dimensions are constant."""
+ if type_ir.WhichOneof("size") == "automatic":
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ ir_data_utils.reader(type_ir).element_count.source_location,
+ "Array dimensions can only be omitted for the outermost dimension.",
+ )
+ ]
+ )
+ elif type_ir.WhichOneof("size") == "element_count":
+ if not ir_util.is_constant(type_ir.element_count):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.element_count.source_location,
+ "Inner array dimensions must be constant.",
+ )
+ ]
+ )
+ else:
+ assert False, 'Expected "element_count" or "automatic" array size.'
-def _check_that_array_base_types_are_fixed_size(type_ir, source_file_name,
- errors, ir):
- """Checks that the sizes of array elements are known at compile time."""
- if type_ir.base_type.HasField("array_type"):
- # An array is fixed size if its base_type is fixed size and its array
- # dimension is constant. This function will be called again on the inner
- # array, and we do not want to cascade errors if the inner array's base_type
- # is not fixed size. The array dimensions are separately checked by
- # _check_that_inner_array_dimensions_are_constant, which will provide an
- # appropriate error message for that case.
- return
- assert type_ir.base_type.HasField("atomic_type")
- if type_ir.base_type.HasField("size_in_bits"):
- # If the base_type has a size_in_bits, then it is fixed size.
- return
- base_type = ir_util.find_object(type_ir.base_type.atomic_type.reference, ir)
- base_type_fixed_size = ir_util.get_integer_attribute(
- base_type.attribute, attributes.FIXED_SIZE)
- if base_type_fixed_size is None:
- errors.append([error.error(source_file_name,
- type_ir.base_type.atomic_type.source_location,
- "Array elements must be fixed size.")])
+def _check_that_array_base_types_are_fixed_size(type_ir, source_file_name, errors, ir):
+ """Checks that the sizes of array elements are known at compile time."""
+ if type_ir.base_type.HasField("array_type"):
+ # An array is fixed size if its base_type is fixed size and its array
+ # dimension is constant. This function will be called again on the inner
+ # array, and we do not want to cascade errors if the inner array's base_type
+ # is not fixed size. The array dimensions are separately checked by
+ # _check_that_inner_array_dimensions_are_constant, which will provide an
+ # appropriate error message for that case.
+ return
+ assert type_ir.base_type.HasField("atomic_type")
+ if type_ir.base_type.HasField("size_in_bits"):
+ # If the base_type has a size_in_bits, then it is fixed size.
+ return
+ base_type = ir_util.find_object(type_ir.base_type.atomic_type.reference, ir)
+ base_type_fixed_size = ir_util.get_integer_attribute(
+ base_type.attribute, attributes.FIXED_SIZE
+ )
+ if base_type_fixed_size is None:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.base_type.atomic_type.source_location,
+ "Array elements must be fixed size.",
+ )
+ ]
+ )
def _check_that_array_base_types_in_structs_are_multiples_of_bytes(
- type_ir, type_definition, source_file_name, errors, ir):
- # TODO(bolms): Remove this limitation.
- """Checks that the sizes of array elements are multiples of 8 bits."""
- if type_ir.base_type.HasField("array_type"):
- # Only check the innermost array for multidimensional arrays.
- return
- assert type_ir.base_type.HasField("atomic_type")
- if type_ir.base_type.HasField("size_in_bits"):
- assert ir_util.is_constant(type_ir.base_type.size_in_bits)
- base_type_size = ir_util.constant_value(type_ir.base_type.size_in_bits)
- else:
- fixed_size = ir_util.fixed_size_of_type_in_bits(type_ir.base_type, ir)
- if fixed_size is None:
- # Variable-sized elements are checked elsewhere.
- return
- base_type_size = fixed_size
- if base_type_size % type_definition.addressable_unit != 0:
- assert type_definition.addressable_unit == ir_data.AddressableUnit.BYTE
- errors.append([error.error(source_file_name,
- type_ir.base_type.source_location,
- "Array elements in structs must have sizes "
- "which are a multiple of 8 bits.")])
+ type_ir, type_definition, source_file_name, errors, ir
+):
+ # TODO(bolms): Remove this limitation.
+ """Checks that the sizes of array elements are multiples of 8 bits."""
+ if type_ir.base_type.HasField("array_type"):
+ # Only check the innermost array for multidimensional arrays.
+ return
+ assert type_ir.base_type.HasField("atomic_type")
+ if type_ir.base_type.HasField("size_in_bits"):
+ assert ir_util.is_constant(type_ir.base_type.size_in_bits)
+ base_type_size = ir_util.constant_value(type_ir.base_type.size_in_bits)
+ else:
+ fixed_size = ir_util.fixed_size_of_type_in_bits(type_ir.base_type, ir)
+ if fixed_size is None:
+ # Variable-sized elements are checked elsewhere.
+ return
+ base_type_size = fixed_size
+ if base_type_size % type_definition.addressable_unit != 0:
+ assert type_definition.addressable_unit == ir_data.AddressableUnit.BYTE
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.base_type.source_location,
+ "Array elements in structs must have sizes "
+ "which are a multiple of 8 bits.",
+ )
+ ]
+ )
-def _check_constancy_of_constant_references(expression, source_file_name,
- errors, ir):
- """Checks that constant_references are constant."""
- if expression.WhichOneof("expression") != "constant_reference":
- return
- # This is a bit of a hack: really, we want to know that the referred-to object
- # has no dependencies on any instance variables of its parent structure; i.e.,
- # that its value does not depend on having a view of the structure.
- if not ir_util.is_constant_type(expression.type):
- referred_name = expression.constant_reference.canonical_name
- referred_object = ir_util.find_object(referred_name, ir)
- errors.append([
- error.error(
- source_file_name, expression.source_location,
- "Static references must refer to constants."),
- error.note(
- referred_name.module_file, referred_object.source_location,
- "{} is not constant.".format(referred_name.object_path[-1]))
- ])
+def _check_constancy_of_constant_references(expression, source_file_name, errors, ir):
+ """Checks that constant_references are constant."""
+ if expression.WhichOneof("expression") != "constant_reference":
+ return
+ # This is a bit of a hack: really, we want to know that the referred-to object
+ # has no dependencies on any instance variables of its parent structure; i.e.,
+ # that its value does not depend on having a view of the structure.
+ if not ir_util.is_constant_type(expression.type):
+ referred_name = expression.constant_reference.canonical_name
+ referred_object = ir_util.find_object(referred_name, ir)
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
+ "Static references must refer to constants.",
+ ),
+ error.note(
+ referred_name.module_file,
+ referred_object.source_location,
+ "{} is not constant.".format(referred_name.object_path[-1]),
+ ),
+ ]
+ )
-def _check_that_enum_values_are_representable(enum_type, type_definition,
- source_file_name, errors):
- """Checks that enumeration values can fit in their specified int type."""
- values = []
- max_enum_size = ir_util.get_integer_attribute(
- type_definition.attribute, attributes.ENUM_MAXIMUM_BITS)
- is_signed = ir_util.get_boolean_attribute(
- type_definition.attribute, attributes.IS_SIGNED)
- if is_signed:
- enum_range = (-(2**(max_enum_size-1)), 2**(max_enum_size-1)-1)
- else:
- enum_range = (0, 2**max_enum_size-1)
- for value in enum_type.value:
- values.append((ir_util.constant_value(value.value), value))
- out_of_range = [v for v in values
- if not enum_range[0] <= v[0] <= enum_range[1]]
- # If all values are in range, this loop will have zero iterations.
- for value in out_of_range:
- errors.append([
- error.error(
- source_file_name, value[1].value.source_location,
- "Value {} is out of range for {}-bit {} enumeration.".format(
- value[0], max_enum_size, "signed" if is_signed else "unsigned"))
- ])
+def _check_that_enum_values_are_representable(
+ enum_type, type_definition, source_file_name, errors
+):
+ """Checks that enumeration values can fit in their specified int type."""
+ values = []
+ max_enum_size = ir_util.get_integer_attribute(
+ type_definition.attribute, attributes.ENUM_MAXIMUM_BITS
+ )
+ is_signed = ir_util.get_boolean_attribute(
+ type_definition.attribute, attributes.IS_SIGNED
+ )
+ if is_signed:
+ enum_range = (-(2 ** (max_enum_size - 1)), 2 ** (max_enum_size - 1) - 1)
+ else:
+ enum_range = (0, 2**max_enum_size - 1)
+ for value in enum_type.value:
+ values.append((ir_util.constant_value(value.value), value))
+ out_of_range = [v for v in values if not enum_range[0] <= v[0] <= enum_range[1]]
+ # If all values are in range, this loop will have zero iterations.
+ for value in out_of_range:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ value[1].value.source_location,
+ "Value {} is out of range for {}-bit {} enumeration.".format(
+ value[0], max_enum_size, "signed" if is_signed else "unsigned"
+ ),
+ )
+ ]
+ )
def _field_size(field, type_definition):
- """Calculates the size of the given field in bits, if it is constant."""
- size = ir_util.constant_value(field.location.size)
- if size is None:
- return None
- return size * type_definition.addressable_unit
+ """Calculates the size of the given field in bits, if it is constant."""
+ size = ir_util.constant_value(field.location.size)
+ if size is None:
+ return None
+ return size * type_definition.addressable_unit
-def _check_type_requirements_for_field(type_ir, type_definition, field, ir,
- source_file_name, errors):
- """Checks that the `requires` attribute of each field's type is fulfilled."""
- if not type_ir.HasField("atomic_type"):
- return
+def _check_type_requirements_for_field(
+ type_ir, type_definition, field, ir, source_file_name, errors
+):
+ """Checks that the `requires` attribute of each field's type is fulfilled."""
+ if not type_ir.HasField("atomic_type"):
+ return
- if field.type.HasField("atomic_type"):
- field_min_size = (int(field.location.size.type.integer.minimum_value) *
- type_definition.addressable_unit)
- field_max_size = (int(field.location.size.type.integer.maximum_value) *
- type_definition.addressable_unit)
- field_is_atomic = True
- else:
- field_is_atomic = False
+ if field.type.HasField("atomic_type"):
+ field_min_size = (
+ int(field.location.size.type.integer.minimum_value)
+ * type_definition.addressable_unit
+ )
+ field_max_size = (
+ int(field.location.size.type.integer.maximum_value)
+ * type_definition.addressable_unit
+ )
+ field_is_atomic = True
+ else:
+ field_is_atomic = False
- if type_ir.HasField("size_in_bits"):
- element_size = ir_util.constant_value(type_ir.size_in_bits)
- else:
- element_size = None
+ if type_ir.HasField("size_in_bits"):
+ element_size = ir_util.constant_value(type_ir.size_in_bits)
+ else:
+ element_size = None
- referenced_type_definition = ir_util.find_object(
- type_ir.atomic_type.reference, ir)
- type_is_anonymous = referenced_type_definition.name.is_anonymous
- type_size_attr = ir_util.get_attribute(
- referenced_type_definition.attribute, attributes.FIXED_SIZE)
- if type_size_attr:
- type_size = ir_util.constant_value(type_size_attr.expression)
- else:
- type_size = None
+ referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir)
+ type_is_anonymous = referenced_type_definition.name.is_anonymous
+ type_size_attr = ir_util.get_attribute(
+ referenced_type_definition.attribute, attributes.FIXED_SIZE
+ )
+ if type_size_attr:
+ type_size = ir_util.constant_value(type_size_attr.expression)
+ else:
+ type_size = None
- if (element_size is not None and type_size is not None and
- element_size != type_size):
- errors.append([
- error.error(
- source_file_name, type_ir.size_in_bits.source_location,
- "Explicit size of {} bits does not match fixed size ({} bits) of "
- "{}.".format(element_size, type_size,
- _render_atomic_type_name(type_ir, ir))),
- error.note(
- type_ir.atomic_type.reference.canonical_name.module_file,
- type_size_attr.source_location,
- "Size specified here.")
- ])
- return
+ if element_size is not None and type_size is not None and element_size != type_size:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.size_in_bits.source_location,
+ "Explicit size of {} bits does not match fixed size ({} bits) of "
+ "{}.".format(
+ element_size, type_size, _render_atomic_type_name(type_ir, ir)
+ ),
+ ),
+ error.note(
+ type_ir.atomic_type.reference.canonical_name.module_file,
+ type_size_attr.source_location,
+ "Size specified here.",
+ ),
+ ]
+ )
+ return
- # If the type had no size specifier (the ':32' in 'UInt:32'), but the type is
- # fixed size, then continue as if the type's size were explicitly stated.
- if element_size is None:
- element_size = type_size
+ # If the type had no size specifier (the ':32' in 'UInt:32'), but the type is
+ # fixed size, then continue as if the type's size were explicitly stated.
+ if element_size is None:
+ element_size = type_size
- # TODO(bolms): When the full dynamic size expression for types is generated,
- # add a check that dynamically-sized types can, at least potentially, fit in
- # their fields.
+ # TODO(bolms): When the full dynamic size expression for types is generated,
+ # add a check that dynamically-sized types can, at least potentially, fit in
+ # their fields.
- if field_is_atomic and element_size is not None:
- # If the field has a fixed size, and the (atomic) type contained therein is
- # also fixed size, then the sizes should match.
- #
- # TODO(bolms): Maybe change the case where the field is bigger than
- # necessary into a warning?
- if (field_max_size == field_min_size and
- (element_size > field_max_size or
- (element_size < field_min_size and not type_is_anonymous))):
- errors.append([
- error.error(
- source_file_name, type_ir.source_location,
- "Fixed-size {} cannot be placed in field of size {} bits; "
- "requires {} bits.".format(
- _render_type(type_ir, ir), field_max_size, element_size))
- ])
- return
- elif element_size > field_max_size:
- errors.append([
- error.error(
- source_file_name, type_ir.source_location,
- "Field of maximum size {} bits cannot hold fixed-size {}, which "
- "requires {} bits.".format(
- field_max_size, _render_type(type_ir, ir), element_size))
- ])
- return
+ if field_is_atomic and element_size is not None:
+ # If the field has a fixed size, and the (atomic) type contained therein is
+ # also fixed size, then the sizes should match.
+ #
+ # TODO(bolms): Maybe change the case where the field is bigger than
+ # necessary into a warning?
+ if field_max_size == field_min_size and (
+ element_size > field_max_size
+ or (element_size < field_min_size and not type_is_anonymous)
+ ):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.source_location,
+ "Fixed-size {} cannot be placed in field of size {} bits; "
+ "requires {} bits.".format(
+ _render_type(type_ir, ir), field_max_size, element_size
+ ),
+ )
+ ]
+ )
+ return
+ elif element_size > field_max_size:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.source_location,
+ "Field of maximum size {} bits cannot hold fixed-size {}, which "
+ "requires {} bits.".format(
+ field_max_size, _render_type(type_ir, ir), element_size
+ ),
+ )
+ ]
+ )
+ return
- # If we're here, then field/type sizes are consistent.
- if (element_size is None and field_is_atomic and
- field_min_size == field_max_size):
- # From here down, we just use element_size.
- element_size = field_min_size
+ # If we're here, then field/type sizes are consistent.
+ if element_size is None and field_is_atomic and field_min_size == field_max_size:
+ # From here down, we just use element_size.
+ element_size = field_min_size
- errors.extend(_check_physical_type_requirements(
- type_ir, field.source_location, element_size, ir, source_file_name))
+ errors.extend(
+ _check_physical_type_requirements(
+ type_ir, field.source_location, element_size, ir, source_file_name
+ )
+ )
def _check_type_requirements_for_parameter_type(
- runtime_parameter, ir, source_file_name, errors):
- """Checks that the type of a parameter is valid."""
- physical_type = runtime_parameter.physical_type_alias
- logical_type = runtime_parameter.type
- size = ir_util.constant_value(physical_type.size_in_bits)
- if logical_type.WhichOneof("type") == "integer":
- integer_errors = _integer_bounds_errors(
- logical_type.integer, "parameter", source_file_name,
- physical_type.source_location)
- if integer_errors:
- errors.extend(integer_errors)
- return
- errors.extend(_check_physical_type_requirements(
- physical_type, runtime_parameter.source_location,
- size, ir, source_file_name))
- elif logical_type.WhichOneof("type") == "enumeration":
- if physical_type.HasField("size_in_bits"):
- # This seems a little weird: for `UInt`, `Int`, etc., the explicit size is
- # required, but for enums it is banned. This is because enums have a
- # "native" 64-bit size in expressions, so the physical size is just
- # ignored.
- errors.extend([[
- error.error(
- source_file_name, physical_type.size_in_bits.source_location,
- "Parameters with enum type may not have explicit size.")
-
- ]])
- else:
- assert False, "Non-integer/enum parameters should have been caught earlier."
+ runtime_parameter, ir, source_file_name, errors
+):
+ """Checks that the type of a parameter is valid."""
+ physical_type = runtime_parameter.physical_type_alias
+ logical_type = runtime_parameter.type
+ size = ir_util.constant_value(physical_type.size_in_bits)
+ if logical_type.WhichOneof("type") == "integer":
+ integer_errors = _integer_bounds_errors(
+ logical_type.integer,
+ "parameter",
+ source_file_name,
+ physical_type.source_location,
+ )
+ if integer_errors:
+ errors.extend(integer_errors)
+ return
+ errors.extend(
+ _check_physical_type_requirements(
+ physical_type,
+ runtime_parameter.source_location,
+ size,
+ ir,
+ source_file_name,
+ )
+ )
+ elif logical_type.WhichOneof("type") == "enumeration":
+ if physical_type.HasField("size_in_bits"):
+ # This seems a little weird: for `UInt`, `Int`, etc., the explicit size is
+ # required, but for enums it is banned. This is because enums have a
+ # "native" 64-bit size in expressions, so the physical size is just
+ # ignored.
+ errors.extend(
+ [
+ [
+ error.error(
+ source_file_name,
+ physical_type.size_in_bits.source_location,
+ "Parameters with enum type may not have explicit size.",
+ )
+ ]
+ ]
+ )
+ else:
+ assert False, "Non-integer/enum parameters should have been caught earlier."
def _check_physical_type_requirements(
- type_ir, usage_source_location, size, ir, source_file_name):
- """Checks that the given atomic `type_ir` is allowed to be `size` bits."""
- referenced_type_definition = ir_util.find_object(
- type_ir.atomic_type.reference, ir)
- if referenced_type_definition.HasField("enumeration"):
+ type_ir, usage_source_location, size, ir, source_file_name
+):
+ """Checks that the given atomic `type_ir` is allowed to be `size` bits."""
+ referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir)
+ if referenced_type_definition.HasField("enumeration"):
+ if size is None:
+ return [
+ [
+ error.error(
+ source_file_name,
+ type_ir.source_location,
+ "Enumeration {} cannot be placed in a dynamically-sized "
+ "field.".format(_render_type(type_ir, ir)),
+ )
+ ]
+ ]
+ else:
+ max_enum_size = ir_util.get_integer_attribute(
+ referenced_type_definition.attribute, attributes.ENUM_MAXIMUM_BITS
+ )
+ if size < 1 or size > max_enum_size:
+ return [
+ [
+ error.error(
+ source_file_name,
+ type_ir.source_location,
+ "Enumeration {} cannot be {} bits; {} must be between "
+ "1 and {} bits, inclusive.".format(
+ _render_atomic_type_name(type_ir, ir),
+ size,
+ _render_atomic_type_name(type_ir, ir),
+ max_enum_size,
+ ),
+ )
+ ]
+ ]
+
if size is None:
- return [[
- error.error(
- source_file_name, type_ir.source_location,
- "Enumeration {} cannot be placed in a dynamically-sized "
- "field.".format(_render_type(type_ir, ir)))
- ]]
+ bindings = {"$is_statically_sized": False}
else:
- max_enum_size = ir_util.get_integer_attribute(
- referenced_type_definition.attribute, attributes.ENUM_MAXIMUM_BITS)
- if size < 1 or size > max_enum_size:
- return [[
- error.error(
- source_file_name, type_ir.source_location,
- "Enumeration {} cannot be {} bits; {} must be between "
- "1 and {} bits, inclusive.".format(
- _render_atomic_type_name(type_ir, ir), size,
- _render_atomic_type_name(type_ir, ir), max_enum_size))
- ]]
-
- if size is None:
- bindings = {"$is_statically_sized": False}
- else:
- bindings = {
- "$is_statically_sized": True,
- "$static_size_in_bits": size
- }
- requires_attr = ir_util.get_attribute(
- referenced_type_definition.attribute, attributes.STATIC_REQUIREMENTS)
- if requires_attr and not ir_util.constant_value(requires_attr.expression,
- bindings):
- # TODO(bolms): Figure out a better way to build this error message.
- # The "Requirements specified here." message should print out the actual
- # source text of the requires attribute, so that should help, but it's still
- # a bit generic and unfriendly.
- return [[
- error.error(
- source_file_name, usage_source_location,
- "Requirements of {} not met.".format(
- type_ir.atomic_type.reference.canonical_name.object_path[-1])),
- error.note(
- type_ir.atomic_type.reference.canonical_name.module_file,
- requires_attr.source_location,
- "Requirements specified here.")
- ]]
- return []
+ bindings = {"$is_statically_sized": True, "$static_size_in_bits": size}
+ requires_attr = ir_util.get_attribute(
+ referenced_type_definition.attribute, attributes.STATIC_REQUIREMENTS
+ )
+ if requires_attr and not ir_util.constant_value(requires_attr.expression, bindings):
+ # TODO(bolms): Figure out a better way to build this error message.
+ # The "Requirements specified here." message should print out the actual
+ # source text of the requires attribute, so that should help, but it's still
+ # a bit generic and unfriendly.
+ return [
+ [
+ error.error(
+ source_file_name,
+ usage_source_location,
+ "Requirements of {} not met.".format(
+ type_ir.atomic_type.reference.canonical_name.object_path[-1]
+ ),
+ ),
+ error.note(
+ type_ir.atomic_type.reference.canonical_name.module_file,
+ requires_attr.source_location,
+ "Requirements specified here.",
+ ),
+ ]
+ ]
+ return []
-def _check_allowed_in_bits(type_ir, type_definition, source_file_name, ir,
- errors):
- if not type_ir.HasField("atomic_type"):
- return
- referenced_type_definition = ir_util.find_object(
- type_ir.atomic_type.reference, ir)
- if (type_definition.addressable_unit %
- referenced_type_definition.addressable_unit != 0):
- assert type_definition.addressable_unit == ir_data.AddressableUnit.BIT
- assert (referenced_type_definition.addressable_unit ==
- ir_data.AddressableUnit.BYTE)
- errors.append([
- error.error(source_file_name, type_ir.source_location,
+def _check_allowed_in_bits(type_ir, type_definition, source_file_name, ir, errors):
+ if not type_ir.HasField("atomic_type"):
+ return
+ referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir)
+ if (
+ type_definition.addressable_unit % referenced_type_definition.addressable_unit
+ != 0
+ ):
+ assert type_definition.addressable_unit == ir_data.AddressableUnit.BIT
+ assert (
+ referenced_type_definition.addressable_unit == ir_data.AddressableUnit.BYTE
+ )
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_ir.source_location,
"Byte-oriented {} cannot be used in a bits field.".format(
- _render_type(type_ir, ir)))
- ])
+ _render_type(type_ir, ir)
+ ),
+ )
+ ]
+ )
def _check_size_of_bits(type_ir, type_definition, source_file_name, errors):
- """Checks that `bits` types are fixed size, less than 64 bits."""
- del type_ir # Unused
- if type_definition.addressable_unit != ir_data.AddressableUnit.BIT:
- return
- fixed_size = ir_util.get_integer_attribute(
- type_definition.attribute, attributes.FIXED_SIZE)
- if fixed_size is None:
- errors.append([error.error(source_file_name,
- type_definition.source_location,
- "`bits` types must be fixed size.")])
- return
- if fixed_size > 64:
- errors.append([error.error(source_file_name,
- type_definition.source_location,
- "`bits` types must be 64 bits or smaller.")])
+ """Checks that `bits` types are fixed size, less than 64 bits."""
+ del type_ir # Unused
+ if type_definition.addressable_unit != ir_data.AddressableUnit.BIT:
+ return
+ fixed_size = ir_util.get_integer_attribute(
+ type_definition.attribute, attributes.FIXED_SIZE
+ )
+ if fixed_size is None:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_definition.source_location,
+ "`bits` types must be fixed size.",
+ )
+ ]
+ )
+ return
+ if fixed_size > 64:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ type_definition.source_location,
+ "`bits` types must be 64 bits or smaller.",
+ )
+ ]
+ )
_RESERVED_WORDS = None
def get_reserved_word_list():
- if _RESERVED_WORDS is None:
- _initialize_reserved_word_list()
- return _RESERVED_WORDS
+ if _RESERVED_WORDS is None:
+ _initialize_reserved_word_list()
+ return _RESERVED_WORDS
def _initialize_reserved_word_list():
- global _RESERVED_WORDS
- _RESERVED_WORDS = {}
- language = None
- for line in resources.load(
- "compiler.front_end", "reserved_words").splitlines():
- stripped_line = line.partition("#")[0].strip()
- if not stripped_line:
- continue
- if stripped_line.startswith("--"):
- language = stripped_line.partition("--")[2].strip()
- else:
- # For brevity's sake, only use the first language for error messages.
- if stripped_line not in _RESERVED_WORDS:
- _RESERVED_WORDS[stripped_line] = language
+ global _RESERVED_WORDS
+ _RESERVED_WORDS = {}
+ language = None
+ for line in resources.load("compiler.front_end", "reserved_words").splitlines():
+ stripped_line = line.partition("#")[0].strip()
+ if not stripped_line:
+ continue
+ if stripped_line.startswith("--"):
+ language = stripped_line.partition("--")[2].strip()
+ else:
+ # For brevity's sake, only use the first language for error messages.
+ if stripped_line not in _RESERVED_WORDS:
+ _RESERVED_WORDS[stripped_line] = language
def _check_name_for_reserved_words(obj, source_file_name, errors, context_name):
- if obj.name.name.text in get_reserved_word_list():
- errors.append([
- error.error(
- source_file_name, obj.name.name.source_location,
- "{} reserved word may not be used as {}.".format(
- get_reserved_word_list()[obj.name.name.text],
- context_name))
- ])
+ if obj.name.name.text in get_reserved_word_list():
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ obj.name.name.source_location,
+ "{} reserved word may not be used as {}.".format(
+ get_reserved_word_list()[obj.name.name.text], context_name
+ ),
+ )
+ ]
+ )
def _check_field_name_for_reserved_words(field, source_file_name, errors):
- return _check_name_for_reserved_words(field, source_file_name, errors,
- "a field name")
+ return _check_name_for_reserved_words(
+ field, source_file_name, errors, "a field name"
+ )
def _check_enum_name_for_reserved_words(enum, source_file_name, errors):
- return _check_name_for_reserved_words(enum, source_file_name, errors,
- "an enum name")
+ return _check_name_for_reserved_words(
+ enum, source_file_name, errors, "an enum name"
+ )
-def _check_type_name_for_reserved_words(type_definition, source_file_name,
- errors):
- return _check_name_for_reserved_words(
- type_definition, source_file_name, errors, "a type name")
+def _check_type_name_for_reserved_words(type_definition, source_file_name, errors):
+ return _check_name_for_reserved_words(
+ type_definition, source_file_name, errors, "a type name"
+ )
def _bounds_can_fit_64_bit_unsigned(minimum, maximum):
- return minimum >= 0 and maximum <= 2**64 - 1
+ return minimum >= 0 and maximum <= 2**64 - 1
def _bounds_can_fit_64_bit_signed(minimum, maximum):
- return minimum >= -(2**63) and maximum <= 2**63 - 1
+ return minimum >= -(2**63) and maximum <= 2**63 - 1
def _bounds_can_fit_any_64_bit_integer_type(minimum, maximum):
- return (_bounds_can_fit_64_bit_unsigned(minimum, maximum) or
- _bounds_can_fit_64_bit_signed(minimum, maximum))
+ return _bounds_can_fit_64_bit_unsigned(
+ minimum, maximum
+ ) or _bounds_can_fit_64_bit_signed(minimum, maximum)
def _integer_bounds_errors_for_expression(expression, source_file_name):
- """Checks that `expression` is in range for int64_t or uint64_t."""
- # Only check non-constant subexpressions.
- if (expression.WhichOneof("expression") == "function" and
- not ir_util.is_constant_type(expression.type)):
- errors = []
- for arg in expression.function.args:
- errors += _integer_bounds_errors_for_expression(arg, source_file_name)
- if errors:
- # Don't cascade bounds errors: report them at the lowest level they
- # appear.
- return errors
- if expression.type.WhichOneof("type") == "integer":
- errors = _integer_bounds_errors(expression.type.integer, "expression",
- source_file_name,
- expression.source_location)
- if errors:
- return errors
- if (expression.WhichOneof("expression") == "function" and
- not ir_util.is_constant_type(expression.type)):
- int64_only_clauses = []
- uint64_only_clauses = []
- for clause in [expression] + list(expression.function.args):
- if clause.type.WhichOneof("type") == "integer":
- arg_minimum = int(clause.type.integer.minimum_value)
- arg_maximum = int(clause.type.integer.maximum_value)
- if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum):
- uint64_only_clauses.append(clause)
- elif not _bounds_can_fit_64_bit_unsigned(arg_minimum, arg_maximum):
- int64_only_clauses.append(clause)
- if int64_only_clauses and uint64_only_clauses:
- error_set = [
- error.error(
- source_file_name, expression.source_location,
- "Either all arguments to '{}' and its result must fit in a "
- "64-bit unsigned integer, or all must fit in a 64-bit signed "
- "integer.".format(expression.function.function_name.text))
- ]
- for signedness, clause_list in (("unsigned", uint64_only_clauses),
- ("signed", int64_only_clauses)):
- for clause in clause_list:
- error_set.append(error.note(
- source_file_name, clause.source_location,
- "Requires {} 64-bit integer.".format(signedness)))
- return [error_set]
- return []
+ """Checks that `expression` is in range for int64_t or uint64_t."""
+ # Only check non-constant subexpressions.
+ if expression.WhichOneof(
+ "expression"
+ ) == "function" and not ir_util.is_constant_type(expression.type):
+ errors = []
+ for arg in expression.function.args:
+ errors += _integer_bounds_errors_for_expression(arg, source_file_name)
+ if errors:
+ # Don't cascade bounds errors: report them at the lowest level they
+ # appear.
+ return errors
+ if expression.type.WhichOneof("type") == "integer":
+ errors = _integer_bounds_errors(
+ expression.type.integer,
+ "expression",
+ source_file_name,
+ expression.source_location,
+ )
+ if errors:
+ return errors
+ if expression.WhichOneof(
+ "expression"
+ ) == "function" and not ir_util.is_constant_type(expression.type):
+ int64_only_clauses = []
+ uint64_only_clauses = []
+ for clause in [expression] + list(expression.function.args):
+ if clause.type.WhichOneof("type") == "integer":
+ arg_minimum = int(clause.type.integer.minimum_value)
+ arg_maximum = int(clause.type.integer.maximum_value)
+ if not _bounds_can_fit_64_bit_signed(arg_minimum, arg_maximum):
+ uint64_only_clauses.append(clause)
+ elif not _bounds_can_fit_64_bit_unsigned(arg_minimum, arg_maximum):
+ int64_only_clauses.append(clause)
+ if int64_only_clauses and uint64_only_clauses:
+ error_set = [
+ error.error(
+ source_file_name,
+ expression.source_location,
+ "Either all arguments to '{}' and its result must fit in a "
+ "64-bit unsigned integer, or all must fit in a 64-bit signed "
+ "integer.".format(expression.function.function_name.text),
+ )
+ ]
+ for signedness, clause_list in (
+ ("unsigned", uint64_only_clauses),
+ ("signed", int64_only_clauses),
+ ):
+ for clause in clause_list:
+ error_set.append(
+ error.note(
+ source_file_name,
+ clause.source_location,
+ "Requires {} 64-bit integer.".format(signedness),
+ )
+ )
+ return [error_set]
+ return []
-def _integer_bounds_errors(bounds, name, source_file_name,
- error_source_location):
- """Returns appropriate errors, if any, for the given integer bounds."""
- assert bounds.minimum_value, "{}".format(bounds)
- assert bounds.maximum_value, "{}".format(bounds)
- if (bounds.minimum_value == "-infinity" or
- bounds.maximum_value == "infinity"):
- return [[
- error.error(
- source_file_name, error_source_location,
- "Integer range of {} must not be unbounded; it must fit "
- "in a 64-bit signed or unsigned integer.".format(name))
- ]]
- if not _bounds_can_fit_any_64_bit_integer_type(int(bounds.minimum_value),
- int(bounds.maximum_value)):
- if int(bounds.minimum_value) == int(bounds.maximum_value):
- return [[
- error.error(
- source_file_name, error_source_location,
- "Constant value {} of {} cannot fit in a 64-bit signed or "
- "unsigned integer.".format(bounds.minimum_value, name))
- ]]
- else:
- return [[
- error.error(
- source_file_name, error_source_location,
- "Potential range of {} is {} to {}, which cannot fit "
- "in a 64-bit signed or unsigned integer.".format(
- name, bounds.minimum_value, bounds.maximum_value))
- ]]
- return []
+def _integer_bounds_errors(bounds, name, source_file_name, error_source_location):
+ """Returns appropriate errors, if any, for the given integer bounds."""
+ assert bounds.minimum_value, "{}".format(bounds)
+ assert bounds.maximum_value, "{}".format(bounds)
+ if bounds.minimum_value == "-infinity" or bounds.maximum_value == "infinity":
+ return [
+ [
+ error.error(
+ source_file_name,
+ error_source_location,
+ "Integer range of {} must not be unbounded; it must fit "
+ "in a 64-bit signed or unsigned integer.".format(name),
+ )
+ ]
+ ]
+ if not _bounds_can_fit_any_64_bit_integer_type(
+ int(bounds.minimum_value), int(bounds.maximum_value)
+ ):
+ if int(bounds.minimum_value) == int(bounds.maximum_value):
+ return [
+ [
+ error.error(
+ source_file_name,
+ error_source_location,
+ "Constant value {} of {} cannot fit in a 64-bit signed or "
+ "unsigned integer.".format(bounds.minimum_value, name),
+ )
+ ]
+ ]
+ else:
+ return [
+ [
+ error.error(
+ source_file_name,
+ error_source_location,
+ "Potential range of {} is {} to {}, which cannot fit "
+ "in a 64-bit signed or unsigned integer.".format(
+ name, bounds.minimum_value, bounds.maximum_value
+ ),
+ )
+ ]
+ ]
+ return []
-def _check_bounds_on_runtime_integer_expressions(expression, source_file_name,
- in_attribute, errors):
- if in_attribute and in_attribute.name.text == attributes.STATIC_REQUIREMENTS:
- # [static_requirements] is never evaluated at runtime, and $size_in_bits is
- # unbounded, so it should not be checked.
- return
- # The logic for gathering errors and suppressing cascades is simpler if
- # errors are just returned, rather than appended to a shared list.
- errors += _integer_bounds_errors_for_expression(expression, source_file_name)
+def _check_bounds_on_runtime_integer_expressions(
+ expression, source_file_name, in_attribute, errors
+):
+ if in_attribute and in_attribute.name.text == attributes.STATIC_REQUIREMENTS:
+ # [static_requirements] is never evaluated at runtime, and $size_in_bits is
+ # unbounded, so it should not be checked.
+ return
+ # The logic for gathering errors and suppressing cascades is simpler if
+ # errors are just returned, rather than appended to a shared list.
+ errors += _integer_bounds_errors_for_expression(expression, source_file_name)
+
def _attribute_in_attribute_action(a):
- return {"in_attribute": a}
+ return {"in_attribute": a}
+
def check_constraints(ir):
- """Checks miscellaneous validity constraints in ir.
+ """Checks miscellaneous validity constraints in ir.
- Checks that auto array sizes are only used for the outermost size of
- multidimensional arrays. That is, Type[3][] is OK, but Type[][3] is not.
+ Checks that auto array sizes are only used for the outermost size of
+ multidimensional arrays. That is, Type[3][] is OK, but Type[][3] is not.
- Checks that fixed-size fields are a correct size to hold statically-sized
- types.
+ Checks that fixed-size fields are a correct size to hold statically-sized
+ types.
- Checks that inner array dimensions are constant.
+ Checks that inner array dimensions are constant.
- Checks that only constant-size types are used in arrays.
+ Checks that only constant-size types are used in arrays.
- Arguments:
- ir: An ir_data.EmbossIr object to check.
+ Arguments:
+ ir: An ir_data.EmbossIr object to check.
- Returns:
- A list of ConstraintViolations, or an empty list if there are none.
- """
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure, ir_data.Type], _check_allowed_in_bits,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- # TODO(bolms): look for [ir_data.ArrayType], [ir_data.AtomicType], and
- # simplify _check_that_array_base_types_are_fixed_size.
- ir, [ir_data.ArrayType], _check_that_array_base_types_are_fixed_size,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure, ir_data.ArrayType],
- _check_that_array_base_types_in_structs_are_multiples_of_bytes,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.ArrayType, ir_data.ArrayType],
- _check_that_inner_array_dimensions_are_constant,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure], _check_size_of_bits,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure, ir_data.Type], _check_type_requirements_for_field,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Field], _check_field_name_for_reserved_words,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.EnumValue], _check_enum_name_for_reserved_words,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.TypeDefinition], _check_type_name_for_reserved_words,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Expression], _check_constancy_of_constant_references,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Enum], _check_that_enum_values_are_representable,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Expression], _check_bounds_on_runtime_integer_expressions,
- incidental_actions={ir_data.Attribute: _attribute_in_attribute_action},
- skip_descendants_of={ir_data.EnumValue, ir_data.Expression},
- parameters={"errors": errors, "in_attribute": None})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.RuntimeParameter],
- _check_type_requirements_for_parameter_type,
- parameters={"errors": errors})
- return errors
+ Returns:
+ A list of ConstraintViolations, or an empty list if there are none.
+ """
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Structure, ir_data.Type],
+ _check_allowed_in_bits,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ # TODO(bolms): look for [ir_data.ArrayType], [ir_data.AtomicType], and
+ # simplify _check_that_array_base_types_are_fixed_size.
+ ir,
+ [ir_data.ArrayType],
+ _check_that_array_base_types_are_fixed_size,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Structure, ir_data.ArrayType],
+ _check_that_array_base_types_in_structs_are_multiples_of_bytes,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.ArrayType, ir_data.ArrayType],
+ _check_that_inner_array_dimensions_are_constant,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Structure], _check_size_of_bits, parameters={"errors": errors}
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Structure, ir_data.Type],
+ _check_type_requirements_for_field,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Field],
+ _check_field_name_for_reserved_words,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.EnumValue],
+ _check_enum_name_for_reserved_words,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.TypeDefinition],
+ _check_type_name_for_reserved_words,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Expression],
+ _check_constancy_of_constant_references,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Enum],
+ _check_that_enum_values_are_representable,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Expression],
+ _check_bounds_on_runtime_integer_expressions,
+ incidental_actions={ir_data.Attribute: _attribute_in_attribute_action},
+ skip_descendants_of={ir_data.EnumValue, ir_data.Expression},
+ parameters={"errors": errors, "in_attribute": None},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.RuntimeParameter],
+ _check_type_requirements_for_parameter_type,
+ parameters={"errors": errors},
+ )
+ return errors
diff --git a/compiler/front_end/constraints_test.py b/compiler/front_end/constraints_test.py
index eac1232..4f4df7e 100644
--- a/compiler/front_end/constraints_test.py
+++ b/compiler/front_end/constraints_test.py
@@ -25,819 +25,1291 @@
def _make_ir_from_emb(emb_text, name="m.emb"):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- name,
- test_util.dict_file_reader({name: emb_text}),
- stop_before_step="check_constraints")
- assert not errors, repr(errors)
- return ir
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ name,
+ test_util.dict_file_reader({name: emb_text}),
+ stop_before_step="check_constraints",
+ )
+ assert not errors, repr(errors)
+ return ir
class ConstraintsTest(unittest.TestCase):
- """Tests constraints.check_constraints and helpers."""
+ """Tests constraints.check_constraints and helpers."""
- def test_error_on_missing_inner_array_size(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt:8[][1] one_byte\n")
- # There is a latent issue here where the source location reported in this
- # error is using a default value of 0:0. An issue is filed at
- # https://github.com/google/emboss/issues/153 for further investigation.
- # In the meantime we use `ir_data_utils.reader` to mimic this legacy
- # behavior.
- error_array = ir_data_utils.reader(
- ir.module[0].type[0].structure.field[0].type.array_type)
- self.assertEqual([[
- error.error(
- "m.emb",
- error_array.base_type.array_type.element_count.source_location,
- "Array dimensions can only be omitted for the outermost dimension.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_error_on_missing_inner_array_size(self):
+ ir = _make_ir_from_emb("struct Foo:\n" " 0 [+1] UInt:8[][1] one_byte\n")
+ # There is a latent issue here where the source location reported in this
+ # error is using a default value of 0:0. An issue is filed at
+ # https://github.com/google/emboss/issues/153 for further investigation.
+ # In the meantime we use `ir_data_utils.reader` to mimic this legacy
+ # behavior.
+ error_array = ir_data_utils.reader(
+ ir.module[0].type[0].structure.field[0].type.array_type
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_array.base_type.array_type.element_count.source_location,
+ "Array dimensions can only be omitted for the outermost dimension.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_no_error_on_ok_array_size(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt:8[1][1] one_byte\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_no_error_on_ok_array_size(self):
+ ir = _make_ir_from_emb("struct Foo:\n" " 0 [+1] UInt:8[1][1] one_byte\n")
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_no_error_on_ok_missing_outer_array_size(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt:8[1][] one_byte\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_no_error_on_ok_missing_outer_array_size(self):
+ ir = _make_ir_from_emb("struct Foo:\n" " 0 [+1] UInt:8[1][] one_byte\n")
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_no_error_on_dynamically_sized_struct_in_dynamically_sized_field(
- self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt size\n"
- " 1 [+size] Bar bar\n"
- "struct Bar:\n"
- " 0 [+1] UInt size\n"
- " 1 [+size] UInt:8[] payload\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_no_error_on_dynamically_sized_struct_in_dynamically_sized_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] UInt size\n"
+ " 1 [+size] Bar bar\n"
+ "struct Bar:\n"
+ " 0 [+1] UInt size\n"
+ " 1 [+size] UInt:8[] payload\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_no_error_on_dynamically_sized_struct_in_statically_sized_field(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+10] Bar bar\n"
- "struct Bar:\n"
- " 0 [+1] UInt size\n"
- " 1 [+size] UInt:8[] payload\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_no_error_on_dynamically_sized_struct_in_statically_sized_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+10] Bar bar\n"
+ "struct Bar:\n"
+ " 0 [+1] UInt size\n"
+ " 1 [+size] UInt:8[] payload\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_no_error_non_fixed_size_outer_array_dimension(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt size\n"
- " 1 [+size] UInt:8[1][size-1] one_byte\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_no_error_non_fixed_size_outer_array_dimension(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] UInt size\n"
+ " 1 [+size] UInt:8[1][size-1] one_byte\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_error_non_fixed_size_inner_array_dimension(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt size\n"
- " 1 [+size] UInt:8[size-1][1] one_byte\n")
- error_array = ir.module[0].type[0].structure.field[1].type.array_type
- self.assertEqual([[
- error.error(
- "m.emb",
- error_array.base_type.array_type.element_count.source_location,
- "Inner array dimensions must be constant.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_error_non_fixed_size_inner_array_dimension(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] UInt size\n"
+ " 1 [+size] UInt:8[size-1][1] one_byte\n"
+ )
+ error_array = ir.module[0].type[0].structure.field[1].type.array_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_array.base_type.array_type.element_count.source_location,
+ "Inner array dimensions must be constant.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_error_non_constant_inner_array_dimensions(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] Bar[1] one_byte\n"
- # There is no dynamically-sized byte-oriented type in
- # the Prelude, so this test has to make its own.
- "external Bar:\n"
- " [is_integer: true]\n"
- " [addressable_unit_size: 8]\n")
- error_array = ir.module[0].type[0].structure.field[0].type.array_type
- self.assertEqual([[
- error.error(
- "m.emb", error_array.base_type.atomic_type.source_location,
- "Array elements must be fixed size.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_error_non_constant_inner_array_dimensions(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] Bar[1] one_byte\n"
+ # There is no dynamically-sized byte-oriented type in
+ # the Prelude, so this test has to make its own.
+ "external Bar:\n"
+ " [is_integer: true]\n"
+ " [addressable_unit_size: 8]\n"
+ )
+ error_array = ir.module[0].type[0].structure.field[0].type.array_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_array.base_type.atomic_type.source_location,
+ "Array elements must be fixed size.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_error_dynamically_sized_array_elements(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] Bar[1] bar\n"
- "struct Bar:\n"
- " 0 [+1] UInt size\n"
- " 1 [+size] UInt:8[] payload\n")
- error_array = ir.module[0].type[0].structure.field[0].type.array_type
- self.assertEqual([[
- error.error(
- "m.emb", error_array.base_type.atomic_type.source_location,
- "Array elements must be fixed size.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_error_dynamically_sized_array_elements(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] Bar[1] bar\n"
+ "struct Bar:\n"
+ " 0 [+1] UInt size\n"
+ " 1 [+size] UInt:8[] payload\n"
+ )
+ error_array = ir.module[0].type[0].structure.field[0].type.array_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_array.base_type.atomic_type.source_location,
+ "Array elements must be fixed size.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_field_too_small_for_type(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] Bar bar\n"
- "struct Bar:\n"
- " 0 [+2] UInt value\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.source_location,
- "Fixed-size type 'Bar' cannot be placed in field of size 8 bits; "
- "requires 16 bits.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_field_too_small_for_type(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] Bar bar\n"
+ "struct Bar:\n"
+ " 0 [+2] UInt value\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Fixed-size type 'Bar' cannot be placed in field of size 8 bits; "
+ "requires 16 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_dynamically_sized_field_always_too_small_for_type(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+1] UInt x\n"
- " 0 [+x] Bar bar\n"
- "struct Bar:\n"
- " 0 [+2] UInt value\n")
- error_type = ir.module[0].type[0].structure.field[2].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.source_location,
- "Field of maximum size 8 bits cannot hold fixed-size type 'Bar', "
- "which requires 16 bits.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_dynamically_sized_field_always_too_small_for_type(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] bits:\n"
+ " 0 [+1] UInt x\n"
+ " 0 [+x] Bar bar\n"
+ "struct Bar:\n"
+ " 0 [+2] UInt value\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[2].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Field of maximum size 8 bits cannot hold fixed-size type 'Bar', "
+ "which requires 16 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_struct_field_too_big_for_type(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] Byte double_byte\n"
- "struct Byte:\n"
- " 0 [+1] UInt b\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.source_location,
- "Fixed-size type 'Byte' cannot be placed in field of size 16 bits; "
- "requires 8 bits.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_struct_field_too_big_for_type(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+2] Byte double_byte\n"
+ "struct Byte:\n"
+ " 0 [+1] UInt b\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Fixed-size type 'Byte' cannot be placed in field of size 16 bits; "
+ "requires 8 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_bits_field_too_big_for_type(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+9] UInt uint72\n"
- ' [byte_order: "LittleEndian"]\n')
- error_field = ir.module[0].type[0].structure.field[0]
- uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir)
- uint_requirements = ir_util.get_attribute(uint_type.attribute,
- attributes.STATIC_REQUIREMENTS)
- self.assertEqual([[
- error.error("m.emb", error_field.source_location,
- "Requirements of UInt not met."),
- error.note("", uint_requirements.source_location,
- "Requirements specified here."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_bits_field_too_big_for_type(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+9] UInt uint72\n"
+ ' [byte_order: "LittleEndian"]\n'
+ )
+ error_field = ir.module[0].type[0].structure.field[0]
+ uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir)
+ uint_requirements = ir_util.get_attribute(
+ uint_type.attribute, attributes.STATIC_REQUIREMENTS
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_field.source_location,
+ "Requirements of UInt not met.",
+ ),
+ error.note(
+ "",
+ uint_requirements.source_location,
+ "Requirements specified here.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_field_type_not_allowed_in_bits(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "bits Foo:\n"
- " 0 [+16] Bar bar\n"
- "external Bar:\n"
- " [addressable_unit_size: 8]\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.source_location,
- "Byte-oriented type 'Bar' cannot be used in a bits field.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_field_type_not_allowed_in_bits(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "bits Foo:\n"
+ " 0 [+16] Bar bar\n"
+ "external Bar:\n"
+ " [addressable_unit_size: 8]\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Byte-oriented type 'Bar' cannot be used in a bits field.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_arrays_allowed_in_bits(self):
- ir = _make_ir_from_emb("bits Foo:\n"
- " 0 [+16] Flag[16] bar\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_arrays_allowed_in_bits(self):
+ ir = _make_ir_from_emb("bits Foo:\n" " 0 [+16] Flag[16] bar\n")
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_oversized_anonymous_bit_field(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+4] bits:\n"
- " 0 [+8] UInt field\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_oversized_anonymous_bit_field(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+4] bits:\n"
+ " 0 [+8] UInt field\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_undersized_anonymous_bit_field(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+32] UInt field\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.source_location,
- "Fixed-size anonymous type cannot be placed in field of size 8 "
- "bits; requires 32 bits.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_undersized_anonymous_bit_field(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] bits:\n"
+ " 0 [+32] UInt field\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Fixed-size anonymous type cannot be placed in field of size 8 "
+ "bits; requires 32 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_reserved_field_name(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+8] UInt restrict\n")
- error_name = ir.module[0].type[0].structure.field[0].name.name
- self.assertEqual([[
- error.error(
- "m.emb", error_name.source_location,
- "C reserved word may not be used as a field name.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_reserved_field_name(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+8] UInt restrict\n"
+ )
+ error_name = ir.module[0].type[0].structure.field[0].name.name
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_name.source_location,
+ "C reserved word may not be used as a field name.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_reserved_type_name(self):
- ir = _make_ir_from_emb("struct False:\n"
- " 0 [+1] UInt foo\n")
- error_name = ir.module[0].type[0].name.name
- self.assertEqual([[
- error.error(
- "m.emb", error_name.source_location,
- "Python 3 reserved word may not be used as a type name.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_reserved_type_name(self):
+ ir = _make_ir_from_emb("struct False:\n" " 0 [+1] UInt foo\n")
+ error_name = ir.module[0].type[0].name.name
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_name.source_location,
+ "Python 3 reserved word may not be used as a type name.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_reserved_enum_name(self):
- ir = _make_ir_from_emb("enum Foo:\n"
- " NULL = 1\n")
- error_name = ir.module[0].type[0].enumeration.value[0].name.name
- self.assertEqual([[
- error.error(
- "m.emb", error_name.source_location,
- "C reserved word may not be used as an enum name.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_reserved_enum_name(self):
+ ir = _make_ir_from_emb("enum Foo:\n" " NULL = 1\n")
+ error_name = ir.module[0].type[0].enumeration.value[0].name.name
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_name.source_location,
+ "C reserved word may not be used as an enum name.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_bits_type_in_struct_array(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+10] UInt:8[10] array\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_bits_type_in_struct_array(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+10] UInt:8[10] array\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_bits_type_in_bits_array(self):
- ir = _make_ir_from_emb("bits Foo:\n"
- " 0 [+10] UInt:8[10] array\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_bits_type_in_bits_array(self):
+ ir = _make_ir_from_emb("bits Foo:\n" " 0 [+10] UInt:8[10] array\n")
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_explicit_size_too_small(self):
- ir = _make_ir_from_emb("bits Foo:\n"
- " 0 [+0] UInt:0 zero_bit\n")
- error_field = ir.module[0].type[0].structure.field[0]
- uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir)
- uint_requirements = ir_util.get_attribute(uint_type.attribute,
- attributes.STATIC_REQUIREMENTS)
- self.assertEqual([[
- error.error("m.emb", error_field.source_location,
- "Requirements of UInt not met."),
- error.note("", uint_requirements.source_location,
- "Requirements specified here."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_size_too_small(self):
+ ir = _make_ir_from_emb("bits Foo:\n" " 0 [+0] UInt:0 zero_bit\n")
+ error_field = ir.module[0].type[0].structure.field[0]
+ uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir)
+ uint_requirements = ir_util.get_attribute(
+ uint_type.attribute, attributes.STATIC_REQUIREMENTS
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_field.source_location,
+ "Requirements of UInt not met.",
+ ),
+ error.note(
+ "",
+ uint_requirements.source_location,
+ "Requirements specified here.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_enumeration_size_too_small(self):
- ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n'
- "bits Foo:\n"
- " 0 [+0] Bar:0 zero_bit\n"
- "enum Bar:\n"
- " BAZ = 0\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error("m.emb", error_type.source_location,
- "Enumeration type 'Bar' cannot be 0 bits; type 'Bar' "
- "must be between 1 and 64 bits, inclusive."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_enumeration_size_too_small(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "BigEndian"]\n'
+ "bits Foo:\n"
+ " 0 [+0] Bar:0 zero_bit\n"
+ "enum Bar:\n"
+ " BAZ = 0\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Enumeration type 'Bar' cannot be 0 bits; type 'Bar' "
+ "must be between 1 and 64 bits, inclusive.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_size_too_big_for_field(self):
- ir = _make_ir_from_emb("bits Foo:\n"
- " 0 [+8] UInt:32 thirty_two_bit\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.source_location,
- "Fixed-size type 'UInt:32' cannot be placed in field of size 8 "
- "bits; requires 32 bits.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_size_too_big_for_field(self):
+ ir = _make_ir_from_emb("bits Foo:\n" " 0 [+8] UInt:32 thirty_two_bit\n")
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Fixed-size type 'UInt:32' cannot be placed in field of size 8 "
+ "bits; requires 32 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_size_too_small_for_field(self):
- ir = _make_ir_from_emb("bits Foo:\n"
- " 0 [+64] UInt:32 thirty_two_bit\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error("m.emb", error_type.source_location,
- "Fixed-size type 'UInt:32' cannot be placed in field of "
- "size 64 bits; requires 32 bits.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_size_too_small_for_field(self):
+ ir = _make_ir_from_emb("bits Foo:\n" " 0 [+64] UInt:32 thirty_two_bit\n")
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Fixed-size type 'UInt:32' cannot be placed in field of "
+ "size 64 bits; requires 32 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_size_too_big(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+16] UInt:128 one_twenty_eight_bit\n"
- ' [byte_order: "LittleEndian"]\n')
- error_field = ir.module[0].type[0].structure.field[0]
- uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir)
- uint_requirements = ir_util.get_attribute(uint_type.attribute,
- attributes.STATIC_REQUIREMENTS)
- self.assertEqual([[
- error.error("m.emb", error_field.source_location,
- "Requirements of UInt not met."),
- error.note("", uint_requirements.source_location,
- "Requirements specified here."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_size_too_big(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+16] UInt:128 one_twenty_eight_bit\n"
+ ' [byte_order: "LittleEndian"]\n'
+ )
+ error_field = ir.module[0].type[0].structure.field[0]
+ uint_type = ir_util.find_object(error_field.type.atomic_type.reference, ir)
+ uint_requirements = ir_util.get_attribute(
+ uint_type.attribute, attributes.STATIC_REQUIREMENTS
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_field.source_location,
+ "Requirements of UInt not met.",
+ ),
+ error.note(
+ "",
+ uint_requirements.source_location,
+ "Requirements specified here.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_enumeration_size_too_big(self):
- ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n'
- "struct Foo:\n"
- " 0 [+9] Bar seventy_two_bit\n"
- "enum Bar:\n"
- " BAZ = 0\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error("m.emb", error_type.source_location,
- "Enumeration type 'Bar' cannot be 72 bits; type 'Bar' " +
- "must be between 1 and 64 bits, inclusive."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_enumeration_size_too_big(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "BigEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+9] Bar seventy_two_bit\n"
+ "enum Bar:\n"
+ " BAZ = 0\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Enumeration type 'Bar' cannot be 72 bits; type 'Bar' "
+ + "must be between 1 and 64 bits, inclusive.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_enumeration_size_too_big_for_small_enum(self):
- ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n'
- "struct Foo:\n"
- " 0 [+8] Bar sixty_four_bit\n"
- "enum Bar:\n"
- " [maximum_bits: 63]\n"
- " BAZ = 0\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error("m.emb", error_type.source_location,
- "Enumeration type 'Bar' cannot be 64 bits; type 'Bar' " +
- "must be between 1 and 63 bits, inclusive."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_enumeration_size_too_big_for_small_enum(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "BigEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+8] Bar sixty_four_bit\n"
+ "enum Bar:\n"
+ " [maximum_bits: 63]\n"
+ " BAZ = 0\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "Enumeration type 'Bar' cannot be 64 bits; type 'Bar' "
+ + "must be between 1 and 63 bits, inclusive.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_size_on_fixed_size_type(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] Byte:8 one_byte\n"
- "struct Byte:\n"
- " 0 [+1] UInt b\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_explicit_size_on_fixed_size_type(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] Byte:8 one_byte\n"
+ "struct Byte:\n"
+ " 0 [+1] UInt b\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_explicit_size_too_small_on_fixed_size_type(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+0] Byte:0 null_byte\n"
- "struct Byte:\n"
- " 0 [+1] UInt b\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.size_in_bits.source_location,
- "Explicit size of 0 bits does not match fixed size (8 bits) of "
- "type 'Byte'."),
- error.note("m.emb", ir.module[0].type[1].source_location,
- "Size specified here."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_size_too_small_on_fixed_size_type(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+0] Byte:0 null_byte\n"
+ "struct Byte:\n"
+ " 0 [+1] UInt b\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.size_in_bits.source_location,
+ "Explicit size of 0 bits does not match fixed size (8 bits) of "
+ "type 'Byte'.",
+ ),
+ error.note(
+ "m.emb",
+ ir.module[0].type[1].source_location,
+ "Size specified here.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_size_too_big_on_fixed_size_type(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+2] Byte:16 double_byte\n"
- "struct Byte:\n"
- " 0 [+1] UInt b\n")
- error_type = ir.module[0].type[0].structure.field[0].type
- self.assertEqual([[
- error.error(
- "m.emb", error_type.size_in_bits.source_location,
- "Explicit size of 16 bits does not match fixed size (8 bits) of "
- "type 'Byte'."),
- error.note(
- "m.emb", ir.module[0].type[1].source_location,
- "Size specified here."),
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_size_too_big_on_fixed_size_type(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+2] Byte:16 double_byte\n"
+ "struct Byte:\n"
+ " 0 [+1] UInt b\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.size_in_bits.source_location,
+ "Explicit size of 16 bits does not match fixed size (8 bits) of "
+ "type 'Byte'.",
+ ),
+ error.note(
+ "m.emb",
+ ir.module[0].type[1].source_location,
+ "Size specified here.",
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_size_ignored_on_variable_size_type(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] UInt n\n"
- " 1 [+n] UInt:8[] d\n"
- "struct Bar:\n"
- " 0 [+10] Foo:80 foo\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_explicit_size_ignored_on_variable_size_type(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] UInt n\n"
+ " 1 [+n] UInt:8[] d\n"
+ "struct Bar:\n"
+ " 0 [+10] Foo:80 foo\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_fixed_size_type_in_dynamically_sized_field(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt bar\n"
- " 0 [+bar] Byte one_byte\n"
- "struct Byte:\n"
- " 0 [+1] UInt b\n")
- self.assertEqual([], constraints.check_constraints(ir))
+ def test_fixed_size_type_in_dynamically_sized_field(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n"
+ " 0 [+1] UInt bar\n"
+ " 0 [+bar] Byte one_byte\n"
+ "struct Byte:\n"
+ " 0 [+1] UInt b\n"
+ )
+ self.assertEqual([], constraints.check_constraints(ir))
- def test_enum_in_dynamically_sized_field(self):
- ir = _make_ir_from_emb('[$default byte_order: "BigEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] UInt bar\n"
- " 0 [+bar] Baz baz\n"
- "enum Baz:\n"
- " QUX = 0\n")
- error_type = ir.module[0].type[0].structure.field[1].type
- self.assertEqual(
- [[
- error.error("m.emb", error_type.source_location,
+ def test_enum_in_dynamically_sized_field(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "BigEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] UInt bar\n"
+ " 0 [+bar] Baz baz\n"
+ "enum Baz:\n"
+ " QUX = 0\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[1].type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
"Enumeration type 'Baz' cannot be placed in a "
- "dynamically-sized field.")
- ]],
- error.filter_errors(constraints.check_constraints(ir)))
+ "dynamically-sized field.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_too_high(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " HIGH = 0x1_0000_0000_0000_0000\n")
- error_value = ir.module[0].type[0].enumeration.value[0].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value.source_location,
- # TODO(bolms): Try to print numbers like 2**64 in hex? (I.e., if a
- # number is a round number in hex, but not in decimal, print in
- # hex?)
- "Value 18446744073709551616 is out of range for 64-bit unsigned " +
- "enumeration.")]
- ], constraints.check_constraints(ir))
+ def test_enum_value_too_high(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " HIGH = 0x1_0000_0000_0000_0000\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[0].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ # TODO(bolms): Try to print numbers like 2**64 in hex? (I.e., if a
+ # number is a round number in hex, but not in decimal, print in
+ # hex?)
+ "Value 18446744073709551616 is out of range for 64-bit unsigned "
+ + "enumeration.",
+ )
+ ]
+ ],
+ constraints.check_constraints(ir),
+ )
- def test_enum_value_too_low(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " LOW = -0x8000_0000_0000_0001\n")
- error_value = ir.module[0].type[0].enumeration.value[0].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value.source_location,
- "Value -9223372036854775809 is out of range for 64-bit signed " +
- "enumeration.")]
- ], constraints.check_constraints(ir))
+ def test_enum_value_too_low(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " LOW = -0x8000_0000_0000_0001\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[0].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value -9223372036854775809 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ]
+ ],
+ constraints.check_constraints(ir),
+ )
- def test_enum_value_too_wide(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " LOW = -1\n"
- " HIGH = 0x8000_0000_0000_0000\n")
- error_value = ir.module[0].type[0].enumeration.value[1].value
- self.assertEqual([[
- error.error(
- "m.emb", error_value.source_location,
- "Value 9223372036854775808 is out of range for 64-bit signed " +
- "enumeration.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_too_wide(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " LOW = -1\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[1].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value 9223372036854775808 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_too_wide_unsigned_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " LOW = -2\n"
- " LOW2 = -1\n"
- " HIGH = 0x8000_0000_0000_0000\n")
- error_value = ir.module[0].type[0].enumeration.value[2].value
- self.assertEqual([[
- error.error(
- "m.emb", error_value.source_location,
- "Value 9223372036854775808 is out of range for 64-bit signed " +
- "enumeration.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_too_wide_unsigned_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " LOW = -2\n"
+ " LOW2 = -1\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[2].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value 9223372036854775808 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_too_wide_small_size_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " [maximum_bits: 8]\n"
- " HIGH = 0x100\n")
- error_value = ir.module[0].type[0].enumeration.value[0].value
- self.assertEqual([[
- error.error(
- "m.emb", error_value.source_location,
- "Value 256 is out of range for 8-bit unsigned enumeration.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_too_wide_small_size_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " [maximum_bits: 8]\n"
+ " HIGH = 0x100\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[0].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value 256 is out of range for 8-bit unsigned enumeration.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_too_wide_small_size_signed_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " [maximum_bits: 8]\n"
- " [is_signed: true]\n"
- " HIGH = 0x80\n")
- error_value = ir.module[0].type[0].enumeration.value[0].value
- self.assertEqual([[
- error.error(
- "m.emb", error_value.source_location,
- "Value 128 is out of range for 8-bit signed enumeration.")
- ]], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_too_wide_small_size_signed_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " [maximum_bits: 8]\n"
+ " [is_signed: true]\n"
+ " HIGH = 0x80\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[0].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value 128 is out of range for 8-bit signed enumeration.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_too_wide_multiple(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " LOW = -2\n"
- " LOW2 = -1\n"
- " HIGH = 0x8000_0000_0000_0000\n"
- " HIGH2 = 0x8000_0000_0000_0001\n")
- error_value = ir.module[0].type[0].enumeration.value[2].value
- error_value2 = ir.module[0].type[0].enumeration.value[3].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value.source_location,
- "Value 9223372036854775808 is out of range for 64-bit signed " +
- "enumeration.")],
- [error.error(
- "m.emb", error_value2.source_location,
- "Value 9223372036854775809 is out of range for 64-bit signed " +
- "enumeration.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_too_wide_multiple(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " LOW = -2\n"
+ " LOW2 = -1\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ " HIGH2 = 0x8000_0000_0000_0001\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[2].value
+ error_value2 = ir.module[0].type[0].enumeration.value[3].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value 9223372036854775808 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ],
+ [
+ error.error(
+ "m.emb",
+ error_value2.source_location,
+ "Value 9223372036854775809 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ],
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_too_wide_multiple_signed_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " LOW = -3\n"
- " LOW2 = -2\n"
- " LOW3 = -1\n"
- " HIGH = 0x8000_0000_0000_0000\n"
- " HIGH2 = 0x8000_0000_0000_0001\n")
- error_value = ir.module[0].type[0].enumeration.value[3].value
- error_value2 = ir.module[0].type[0].enumeration.value[4].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value.source_location,
- "Value 9223372036854775808 is out of range for 64-bit signed "
- "enumeration.")],
- [error.error(
- "m.emb", error_value2.source_location,
- "Value 9223372036854775809 is out of range for 64-bit signed "
- "enumeration.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_too_wide_multiple_signed_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " LOW = -3\n"
+ " LOW2 = -2\n"
+ " LOW3 = -1\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ " HIGH2 = 0x8000_0000_0000_0001\n"
+ )
+ error_value = ir.module[0].type[0].enumeration.value[3].value
+ error_value2 = ir.module[0].type[0].enumeration.value[4].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value.source_location,
+ "Value 9223372036854775808 is out of range for 64-bit signed "
+ "enumeration.",
+ )
+ ],
+ [
+ error.error(
+ "m.emb",
+ error_value2.source_location,
+ "Value 9223372036854775809 is out of range for 64-bit signed "
+ "enumeration.",
+ )
+ ],
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_mixed_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " LOW = -1\n"
- " HIGH = 0x8000_0000_0000_0000\n"
- " HIGH2 = 0x1_0000_0000_0000_0000\n")
- error_value1 = ir.module[0].type[0].enumeration.value[1].value
- error_value2 = ir.module[0].type[0].enumeration.value[2].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value1.source_location,
- "Value 9223372036854775808 is out of range for 64-bit signed " +
- "enumeration.")],
- [error.error(
- "m.emb", error_value2.source_location,
- "Value 18446744073709551616 is out of range for 64-bit signed " +
- "enumeration.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_mixed_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " LOW = -1\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ " HIGH2 = 0x1_0000_0000_0000_0000\n"
+ )
+ error_value1 = ir.module[0].type[0].enumeration.value[1].value
+ error_value2 = ir.module[0].type[0].enumeration.value[2].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value1.source_location,
+ "Value 9223372036854775808 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ],
+ [
+ error.error(
+ "m.emb",
+ error_value2.source_location,
+ "Value 18446744073709551616 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ],
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_explicitly_signed_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " [is_signed: true]\n"
- " HIGH = 0x8000_0000_0000_0000\n"
- " HIGH2 = 0x1_0000_0000_0000_0000\n")
- error_value0 = ir.module[0].type[0].enumeration.value[0].value
- error_value1 = ir.module[0].type[0].enumeration.value[1].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value0.source_location,
- "Value 9223372036854775808 is out of range for 64-bit signed " +
- "enumeration.")],
- [error.error(
- "m.emb", error_value1.source_location,
- "Value 18446744073709551616 is out of range for 64-bit signed " +
- "enumeration.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_explicitly_signed_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " [is_signed: true]\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ " HIGH2 = 0x1_0000_0000_0000_0000\n"
+ )
+ error_value0 = ir.module[0].type[0].enumeration.value[0].value
+ error_value1 = ir.module[0].type[0].enumeration.value[1].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value0.source_location,
+ "Value 9223372036854775808 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ],
+ [
+ error.error(
+ "m.emb",
+ error_value1.source_location,
+ "Value 18446744073709551616 is out of range for 64-bit signed "
+ + "enumeration.",
+ )
+ ],
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_enum_value_explicitly_unsigned_error_message(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "enum Foo:\n"
- " [is_signed: false]\n"
- " LOW = -1\n"
- " HIGH = 0x8000_0000_0000_0000\n"
- " HIGH2 = 0x1_0000_0000_0000_0000\n")
- error_value0 = ir.module[0].type[0].enumeration.value[0].value
- error_value2 = ir.module[0].type[0].enumeration.value[2].value
- self.assertEqual([
- [error.error(
- "m.emb", error_value0.source_location,
- "Value -1 is out of range for 64-bit unsigned enumeration.")],
- [error.error(
- "m.emb", error_value2.source_location,
- "Value 18446744073709551616 is out of range for 64-bit unsigned " +
- "enumeration.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_enum_value_explicitly_unsigned_error_message(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "enum Foo:\n"
+ " [is_signed: false]\n"
+ " LOW = -1\n"
+ " HIGH = 0x8000_0000_0000_0000\n"
+ " HIGH2 = 0x1_0000_0000_0000_0000\n"
+ )
+ error_value0 = ir.module[0].type[0].enumeration.value[0].value
+ error_value2 = ir.module[0].type[0].enumeration.value[2].value
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_value0.source_location,
+ "Value -1 is out of range for 64-bit unsigned enumeration.",
+ )
+ ],
+ [
+ error.error(
+ "m.emb",
+ error_value2.source_location,
+ "Value 18446744073709551616 is out of range for 64-bit unsigned "
+ + "enumeration.",
+ )
+ ],
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_explicit_non_byte_size_array_element(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+2] UInt:4[4] nibbles\n")
- error_type = ir.module[0].type[0].structure.field[0].type.array_type
- self.assertEqual([
- [error.error(
- "m.emb", error_type.base_type.source_location,
- "Array elements in structs must have sizes which are a multiple of "
- "8 bits.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_explicit_non_byte_size_array_element(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+2] UInt:4[4] nibbles\n"
+ )
+ error_type = ir.module[0].type[0].structure.field[0].type.array_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.base_type.source_location,
+ "Array elements in structs must have sizes which are a multiple of "
+ "8 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_implicit_non_byte_size_array_element(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "bits Nibble:\n"
- " 0 [+4] UInt nibble\n"
- "struct Foo:\n"
- " 0 [+2] Nibble[4] nibbles\n")
- error_type = ir.module[0].type[1].structure.field[0].type.array_type
- self.assertEqual([
- [error.error(
- "m.emb", error_type.base_type.source_location,
- "Array elements in structs must have sizes which are a multiple of "
- "8 bits.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_implicit_non_byte_size_array_element(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "bits Nibble:\n"
+ " 0 [+4] UInt nibble\n"
+ "struct Foo:\n"
+ " 0 [+2] Nibble[4] nibbles\n"
+ )
+ error_type = ir.module[0].type[1].structure.field[0].type.array_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.base_type.source_location,
+ "Array elements in structs must have sizes which are a multiple of "
+ "8 bits.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_bits_must_be_fixed_size(self):
- ir = _make_ir_from_emb("bits Dynamic:\n"
- " 0 [+3] UInt x\n"
- " 3 [+3 * x] UInt:3[x] a\n")
- error_type = ir.module[0].type[0]
- self.assertEqual([
- [error.error("m.emb", error_type.source_location,
- "`bits` types must be fixed size.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_bits_must_be_fixed_size(self):
+ ir = _make_ir_from_emb(
+ "bits Dynamic:\n"
+ " 0 [+3] UInt x\n"
+ " 3 [+3 * x] UInt:3[x] a\n"
+ )
+ error_type = ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "`bits` types must be fixed size.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_bits_must_be_small(self):
- ir = _make_ir_from_emb("bits Big:\n"
- " 0 [+64] UInt x\n"
- " 64 [+1] UInt y\n")
- error_type = ir.module[0].type[0]
- self.assertEqual([
- [error.error("m.emb", error_type.source_location,
- "`bits` types must be 64 bits or smaller.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_bits_must_be_small(self):
+ ir = _make_ir_from_emb(
+ "bits Big:\n" " 0 [+64] UInt x\n" " 64 [+1] UInt y\n"
+ )
+ error_type = ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_type.source_location,
+ "`bits` types must be 64 bits or smaller.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_constant_expressions_must_be_small(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+8] UInt x\n"
- " if x < 0x1_0000_0000_0000_0000:\n"
- " 8 [+1] UInt y\n")
- condition = ir.module[0].type[0].structure.field[1].existence_condition
- error_location = condition.function.args[1].source_location
- self.assertEqual([
- [error.error(
- "m.emb", error_location,
- "Constant value {} of expression cannot fit in a 64-bit signed or "
- "unsigned integer.".format(2**64))]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_constant_expressions_must_be_small(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+8] UInt x\n"
+ " if x < 0x1_0000_0000_0000_0000:\n"
+ " 8 [+1] UInt y\n"
+ )
+ condition = ir.module[0].type[0].structure.field[1].existence_condition
+ error_location = condition.function.args[1].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Constant value {} of expression cannot fit in a 64-bit signed or "
+ "unsigned integer.".format(2**64),
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_variable_expression_out_of_range_for_uint64(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+8] UInt x\n"
- " if x + 1 < 0xffff_ffff_ffff_ffff:\n"
- " 8 [+1] UInt y\n")
- condition = ir.module[0].type[0].structure.field[1].existence_condition
- error_location = condition.function.args[0].source_location
- self.assertEqual([
- [error.error(
- "m.emb", error_location,
- "Potential range of expression is {} to {}, which cannot fit in a "
- "64-bit signed or unsigned integer.".format(1, 2**64))]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_variable_expression_out_of_range_for_uint64(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+8] UInt x\n"
+ " if x + 1 < 0xffff_ffff_ffff_ffff:\n"
+ " 8 [+1] UInt y\n"
+ )
+ condition = ir.module[0].type[0].structure.field[1].existence_condition
+ error_location = condition.function.args[0].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Potential range of expression is {} to {}, which cannot fit in a "
+ "64-bit signed or unsigned integer.".format(1, 2**64),
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_variable_expression_out_of_range_for_int64(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+8] UInt x\n"
- " if x - 0x8000_0000_0000_0001 < 0:\n"
- " 8 [+1] UInt y\n")
- condition = ir.module[0].type[0].structure.field[1].existence_condition
- error_location = condition.function.args[0].source_location
- self.assertEqual([
- [error.error(
- "m.emb", error_location,
- "Potential range of expression is {} to {}, which cannot fit in a "
- "64-bit signed or unsigned integer.".format(-(2**63) - 1,
- 2**63 - 2))]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_variable_expression_out_of_range_for_int64(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+8] UInt x\n"
+ " if x - 0x8000_0000_0000_0001 < 0:\n"
+ " 8 [+1] UInt y\n"
+ )
+ condition = ir.module[0].type[0].structure.field[1].existence_condition
+ error_location = condition.function.args[0].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Potential range of expression is {} to {}, which cannot fit in a "
+ "64-bit signed or unsigned integer.".format(
+ -(2**63) - 1, 2**63 - 2
+ ),
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_requires_expression_out_of_range_for_uint64(self):
- ir = _make_ir_from_emb('[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+8] UInt x\n"
- " [requires: this * 2 < 0x1_0000]\n")
- attribute_list = ir.module[0].type[0].structure.field[0].attribute
- error_arg = attribute_list[0].value.expression.function.args[0]
- error_location = error_arg.source_location
- self.assertEqual(
- [[
- error.error(
- "m.emb", error_location,
- "Potential range of expression is {} to {}, which cannot fit "
- "in a 64-bit signed or unsigned integer.".format(0, 2**65-2))
- ]],
- error.filter_errors(constraints.check_constraints(ir)))
+ def test_requires_expression_out_of_range_for_uint64(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+8] UInt x\n"
+ " [requires: this * 2 < 0x1_0000]\n"
+ )
+ attribute_list = ir.module[0].type[0].structure.field[0].attribute
+ error_arg = attribute_list[0].value.expression.function.args[0]
+ error_location = error_arg.source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Potential range of expression is {} to {}, which cannot fit "
+ "in a 64-bit signed or unsigned integer.".format(0, 2**65 - 2),
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_arguments_require_different_signedness_64_bits(self):
- ir = _make_ir_from_emb(
- '[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] UInt x\n"
- # Left side requires uint64, right side requires int64.
- " if (x + 0x8000_0000_0000_0000) + (x - 0x7fff_ffff_ffff_ffff) < 10:\n"
- " 1 [+1] UInt y\n")
- condition = ir.module[0].type[0].structure.field[1].existence_condition
- error_expression = condition.function.args[0]
- error_location = error_expression.source_location
- arg0_location = error_expression.function.args[0].source_location
- arg1_location = error_expression.function.args[1].source_location
- self.assertEqual([
- [error.error(
- "m.emb", error_location,
- "Either all arguments to '+' and its result must fit in a 64-bit "
- "unsigned integer, or all must fit in a 64-bit signed integer."),
- error.note("m.emb", arg0_location,
- "Requires unsigned 64-bit integer."),
- error.note("m.emb", arg1_location,
- "Requires signed 64-bit integer.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_arguments_require_different_signedness_64_bits(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ # Left side requires uint64, right side requires int64.
+ " if (x + 0x8000_0000_0000_0000) + (x - 0x7fff_ffff_ffff_ffff) < 10:\n"
+ " 1 [+1] UInt y\n"
+ )
+ condition = ir.module[0].type[0].structure.field[1].existence_condition
+ error_expression = condition.function.args[0]
+ error_location = error_expression.source_location
+ arg0_location = error_expression.function.args[0].source_location
+ arg1_location = error_expression.function.args[1].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Either all arguments to '+' and its result must fit in a 64-bit "
+ "unsigned integer, or all must fit in a 64-bit signed integer.",
+ ),
+ error.note(
+ "m.emb", arg0_location, "Requires unsigned 64-bit integer."
+ ),
+ error.note(
+ "m.emb", arg1_location, "Requires signed 64-bit integer."
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_return_value_requires_different_signedness_from_arguments(self):
- ir = _make_ir_from_emb(
- '[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] UInt x\n"
- # Both arguments require uint64; result fits in int64.
- " if (x + 0x7fff_ffff_ffff_ffff) - 0x8000_0000_0000_0000 < 10:\n"
- " 1 [+1] UInt y\n")
- condition = ir.module[0].type[0].structure.field[1].existence_condition
- error_expression = condition.function.args[0]
- error_location = error_expression.source_location
- arg0_location = error_expression.function.args[0].source_location
- arg1_location = error_expression.function.args[1].source_location
- self.assertEqual([
- [error.error(
- "m.emb", error_location,
- "Either all arguments to '-' and its result must fit in a 64-bit "
- "unsigned integer, or all must fit in a 64-bit signed integer."),
- error.note("m.emb", arg0_location,
- "Requires unsigned 64-bit integer."),
- error.note("m.emb", arg1_location,
- "Requires unsigned 64-bit integer."),
- error.note("m.emb", error_location,
- "Requires signed 64-bit integer.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_return_value_requires_different_signedness_from_arguments(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ # Both arguments require uint64; result fits in int64.
+ " if (x + 0x7fff_ffff_ffff_ffff) - 0x8000_0000_0000_0000 < 10:\n"
+ " 1 [+1] UInt y\n"
+ )
+ condition = ir.module[0].type[0].structure.field[1].existence_condition
+ error_expression = condition.function.args[0]
+ error_location = error_expression.source_location
+ arg0_location = error_expression.function.args[0].source_location
+ arg1_location = error_expression.function.args[1].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Either all arguments to '-' and its result must fit in a 64-bit "
+ "unsigned integer, or all must fit in a 64-bit signed integer.",
+ ),
+ error.note(
+ "m.emb", arg0_location, "Requires unsigned 64-bit integer."
+ ),
+ error.note(
+ "m.emb", arg1_location, "Requires unsigned 64-bit integer."
+ ),
+ error.note(
+ "m.emb", error_location, "Requires signed 64-bit integer."
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_return_value_requires_different_signedness_from_one_argument(self):
- ir = _make_ir_from_emb(
- '[$default byte_order: "LittleEndian"]\n'
- "struct Foo:\n"
- " 0 [+1] UInt x\n"
- # One argument requires uint64; result fits in int64.
- " if (x + 0x7fff_ffff_ffff_fff0) - 0x7fff_ffff_ffff_ffff < 10:\n"
- " 1 [+1] UInt y\n")
- condition = ir.module[0].type[0].structure.field[1].existence_condition
- error_expression = condition.function.args[0]
- error_location = error_expression.source_location
- arg0_location = error_expression.function.args[0].source_location
- self.assertEqual([
- [error.error(
- "m.emb", error_location,
- "Either all arguments to '-' and its result must fit in a 64-bit "
- "unsigned integer, or all must fit in a 64-bit signed integer."),
- error.note("m.emb", arg0_location,
- "Requires unsigned 64-bit integer."),
- error.note("m.emb", error_location,
- "Requires signed 64-bit integer.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_return_value_requires_different_signedness_from_one_argument(self):
+ ir = _make_ir_from_emb(
+ '[$default byte_order: "LittleEndian"]\n'
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ # One argument requires uint64; result fits in int64.
+ " if (x + 0x7fff_ffff_ffff_fff0) - 0x7fff_ffff_ffff_ffff < 10:\n"
+ " 1 [+1] UInt y\n"
+ )
+ condition = ir.module[0].type[0].structure.field[1].existence_condition
+ error_expression = condition.function.args[0]
+ error_location = error_expression.source_location
+ arg0_location = error_expression.function.args[0].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Either all arguments to '-' and its result must fit in a 64-bit "
+ "unsigned integer, or all must fit in a 64-bit signed integer.",
+ ),
+ error.note(
+ "m.emb", arg0_location, "Requires unsigned 64-bit integer."
+ ),
+ error.note(
+ "m.emb", error_location, "Requires signed 64-bit integer."
+ ),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_checks_constancy_of_constant_references(self):
- ir = _make_ir_from_emb("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = x\n"
- " let z = Foo.y\n")
- error_expression = ir.module[0].type[0].structure.field[2].read_transform
- error_location = error_expression.source_location
- note_field = ir.module[0].type[0].structure.field[1]
- note_location = note_field.source_location
- self.assertEqual([
- [error.error("m.emb", error_location,
- "Static references must refer to constants."),
- error.note("m.emb", note_location, "y is not constant.")]
- ], error.filter_errors(constraints.check_constraints(ir)))
+ def test_checks_constancy_of_constant_references(self):
+ ir = _make_ir_from_emb(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " let y = x\n" " let z = Foo.y\n"
+ )
+ error_expression = ir.module[0].type[0].structure.field[2].read_transform
+ error_location = error_expression.source_location
+ note_field = ir.module[0].type[0].structure.field[1]
+ note_location = note_field.source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Static references must refer to constants.",
+ ),
+ error.note("m.emb", note_location, "y is not constant."),
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_checks_for_explicit_size_on_parameters(self):
- ir = _make_ir_from_emb("struct Foo(y: UInt):\n"
- " 0 [+1] UInt x\n")
- error_parameter = ir.module[0].type[0].runtime_parameter[0]
- error_location = error_parameter.physical_type_alias.source_location
- self.assertEqual(
- [[error.error("m.emb", error_location,
- "Integer range of parameter must not be unbounded; it "
- "must fit in a 64-bit signed or unsigned integer.")]],
- error.filter_errors(constraints.check_constraints(ir)))
+ def test_checks_for_explicit_size_on_parameters(self):
+ ir = _make_ir_from_emb("struct Foo(y: UInt):\n" " 0 [+1] UInt x\n")
+ error_parameter = ir.module[0].type[0].runtime_parameter[0]
+ error_location = error_parameter.physical_type_alias.source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Integer range of parameter must not be unbounded; it "
+ "must fit in a 64-bit signed or unsigned integer.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_checks_for_correct_explicit_size_on_parameters(self):
- ir = _make_ir_from_emb("struct Foo(y: UInt:300):\n"
- " 0 [+1] UInt x\n")
- error_parameter = ir.module[0].type[0].runtime_parameter[0]
- error_location = error_parameter.physical_type_alias.source_location
- self.assertEqual(
- [[error.error("m.emb", error_location,
- "Potential range of parameter is 0 to {}, which cannot "
- "fit in a 64-bit signed or unsigned integer.".format(
- 2**300-1))]],
- error.filter_errors(constraints.check_constraints(ir)))
+ def test_checks_for_correct_explicit_size_on_parameters(self):
+ ir = _make_ir_from_emb("struct Foo(y: UInt:300):\n" " 0 [+1] UInt x\n")
+ error_parameter = ir.module[0].type[0].runtime_parameter[0]
+ error_location = error_parameter.physical_type_alias.source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Potential range of parameter is 0 to {}, which cannot "
+ "fit in a 64-bit signed or unsigned integer.".format(
+ 2**300 - 1
+ ),
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
- def test_checks_for_explicit_enum_size_on_parameters(self):
- ir = _make_ir_from_emb("struct Foo(y: Bar:8):\n"
- " 0 [+1] UInt x\n"
- "enum Bar:\n"
- " QUX = 1\n")
- error_parameter = ir.module[0].type[0].runtime_parameter[0]
- error_size = error_parameter.physical_type_alias.size_in_bits
- error_location = error_size.source_location
- self.assertEqual(
- [[error.error(
- "m.emb", error_location,
- "Parameters with enum type may not have explicit size.")]],
- error.filter_errors(constraints.check_constraints(ir)))
+ def test_checks_for_explicit_enum_size_on_parameters(self):
+ ir = _make_ir_from_emb(
+ "struct Foo(y: Bar:8):\n" " 0 [+1] UInt x\n" "enum Bar:\n" " QUX = 1\n"
+ )
+ error_parameter = ir.module[0].type[0].runtime_parameter[0]
+ error_size = error_parameter.physical_type_alias.size_in_bits
+ error_location = error_size.source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ error_location,
+ "Parameters with enum type may not have explicit size.",
+ )
+ ]
+ ],
+ error.filter_errors(constraints.check_constraints(ir)),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/dependency_checker.py b/compiler/front_end/dependency_checker.py
index 963d8bf..8a9e903 100644
--- a/compiler/front_end/dependency_checker.py
+++ b/compiler/front_end/dependency_checker.py
@@ -20,261 +20,302 @@
from compiler.util import traverse_ir
-def _add_reference_to_dependencies(reference, dependencies, name,
- source_file_name, errors):
- if reference.canonical_name.object_path[0] in {"$is_statically_sized",
- "$static_size_in_bits",
- "$next"}:
- # This error is a bit opaque, but given that the compiler used to crash on
- # this case -- for a couple of years -- and no one complained, it seems
- # safe to assume that this is a rare error.
- errors.append([
- error.error(source_file_name, reference.source_location,
- "Keyword `" + reference.canonical_name.object_path[0] +
- "` may not be used in this context."),
- ])
- return
- dependencies[name] |= {ir_util.hashable_form_of_reference(reference)}
+def _add_reference_to_dependencies(
+ reference, dependencies, name, source_file_name, errors
+):
+ if reference.canonical_name.object_path[0] in {
+ "$is_statically_sized",
+ "$static_size_in_bits",
+ "$next",
+ }:
+ # This error is a bit opaque, but given that the compiler used to crash on
+ # this case -- for a couple of years -- and no one complained, it seems
+ # safe to assume that this is a rare error.
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ reference.source_location,
+ "Keyword `"
+ + reference.canonical_name.object_path[0]
+ + "` may not be used in this context.",
+ ),
+ ]
+ )
+ return
+ dependencies[name] |= {ir_util.hashable_form_of_reference(reference)}
def _add_field_reference_to_dependencies(reference, dependencies, name):
- dependencies[name] |= {ir_util.hashable_form_of_reference(reference.path[0])}
+ dependencies[name] |= {ir_util.hashable_form_of_reference(reference.path[0])}
def _add_name_to_dependencies(proto, dependencies):
- name = ir_util.hashable_form_of_reference(proto.name)
- dependencies.setdefault(name, set())
- return {"name": name}
+ name = ir_util.hashable_form_of_reference(proto.name)
+ dependencies.setdefault(name, set())
+ return {"name": name}
def _find_dependencies(ir):
- """Constructs a dependency graph for the entire IR."""
- dependencies = {}
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Reference], _add_reference_to_dependencies,
- # TODO(bolms): Add handling for references inside of attributes, once
- # there are attributes with non-constant values.
- skip_descendants_of={
- ir_data.AtomicType, ir_data.Attribute, ir_data.FieldReference
- },
- incidental_actions={
- ir_data.Field: _add_name_to_dependencies,
- ir_data.EnumValue: _add_name_to_dependencies,
- ir_data.RuntimeParameter: _add_name_to_dependencies,
- },
- parameters={
- "dependencies": dependencies,
- "errors": errors,
- })
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.FieldReference], _add_field_reference_to_dependencies,
- skip_descendants_of={ir_data.Attribute},
- incidental_actions={
- ir_data.Field: _add_name_to_dependencies,
- ir_data.EnumValue: _add_name_to_dependencies,
- ir_data.RuntimeParameter: _add_name_to_dependencies,
- },
- parameters={"dependencies": dependencies})
- return dependencies, errors
+ """Constructs a dependency graph for the entire IR."""
+ dependencies = {}
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Reference],
+ _add_reference_to_dependencies,
+ # TODO(bolms): Add handling for references inside of attributes, once
+ # there are attributes with non-constant values.
+ skip_descendants_of={
+ ir_data.AtomicType,
+ ir_data.Attribute,
+ ir_data.FieldReference,
+ },
+ incidental_actions={
+ ir_data.Field: _add_name_to_dependencies,
+ ir_data.EnumValue: _add_name_to_dependencies,
+ ir_data.RuntimeParameter: _add_name_to_dependencies,
+ },
+ parameters={
+ "dependencies": dependencies,
+ "errors": errors,
+ },
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.FieldReference],
+ _add_field_reference_to_dependencies,
+ skip_descendants_of={ir_data.Attribute},
+ incidental_actions={
+ ir_data.Field: _add_name_to_dependencies,
+ ir_data.EnumValue: _add_name_to_dependencies,
+ ir_data.RuntimeParameter: _add_name_to_dependencies,
+ },
+ parameters={"dependencies": dependencies},
+ )
+ return dependencies, errors
def _find_dependency_ordering_for_fields_in_structure(
- structure, type_definition, dependencies):
- """Populates structure.fields_in_dependency_order."""
- # For fields which appear before their dependencies in the original source
- # text, this algorithm moves them to immediately after their dependencies.
- #
- # This is one of many possible schemes for constructing a dependency ordering;
- # it has the advantage that all of the generated fields (e.g., $size_in_bytes)
- # stay at the end of the ordering, which makes testing easier.
- order = []
- added = set()
- for parameter in type_definition.runtime_parameter:
- added.add(ir_util.hashable_form_of_reference(parameter.name))
- needed = list(range(len(structure.field)))
- while True:
- for i in range(len(needed)):
- field_number = needed[i]
- field = ir_util.hashable_form_of_reference(
- structure.field[field_number].name)
- assert field in dependencies, "dependencies = {}".format(dependencies)
- if all(dependency in added for dependency in dependencies[field]):
- order.append(field_number)
- added.add(field)
- del needed[i]
- break
- else:
- break
- # If a non-local-field dependency were in dependencies[field], then not all
- # fields would be added to the dependency ordering. This shouldn't happen.
- assert len(order) == len(structure.field), (
- "order: {}\nlen(structure.field: {})".format(order, len(structure.field)))
- del structure.fields_in_dependency_order[:]
- structure.fields_in_dependency_order.extend(order)
+ structure, type_definition, dependencies
+):
+ """Populates structure.fields_in_dependency_order."""
+ # For fields which appear before their dependencies in the original source
+ # text, this algorithm moves them to immediately after their dependencies.
+ #
+ # This is one of many possible schemes for constructing a dependency ordering;
+ # it has the advantage that all of the generated fields (e.g., $size_in_bytes)
+ # stay at the end of the ordering, which makes testing easier.
+ order = []
+ added = set()
+ for parameter in type_definition.runtime_parameter:
+ added.add(ir_util.hashable_form_of_reference(parameter.name))
+ needed = list(range(len(structure.field)))
+ while True:
+ for i in range(len(needed)):
+ field_number = needed[i]
+ field = ir_util.hashable_form_of_reference(
+ structure.field[field_number].name
+ )
+ assert field in dependencies, "dependencies = {}".format(dependencies)
+ if all(dependency in added for dependency in dependencies[field]):
+ order.append(field_number)
+ added.add(field)
+ del needed[i]
+ break
+ else:
+ break
+ # If a non-local-field dependency were in dependencies[field], then not all
+ # fields would be added to the dependency ordering. This shouldn't happen.
+ assert len(order) == len(
+ structure.field
+    ), "order: {}\nlen(structure.field): {}".format(order, len(structure.field))
+ del structure.fields_in_dependency_order[:]
+ structure.fields_in_dependency_order.extend(order)
def _find_dependency_ordering_for_fields(ir):
- """Populates the fields_in_dependency_order fields throughout ir."""
- dependencies = {}
- # TODO(bolms): This duplicates work in _find_dependencies that could be
- # shared.
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.FieldReference], _add_field_reference_to_dependencies,
- skip_descendants_of={ir_data.Attribute},
- incidental_actions={
- ir_data.Field: _add_name_to_dependencies,
- ir_data.EnumValue: _add_name_to_dependencies,
- ir_data.RuntimeParameter: _add_name_to_dependencies,
- },
- parameters={"dependencies": dependencies})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure],
- _find_dependency_ordering_for_fields_in_structure,
- parameters={"dependencies": dependencies})
+ """Populates the fields_in_dependency_order fields throughout ir."""
+ dependencies = {}
+ # TODO(bolms): This duplicates work in _find_dependencies that could be
+ # shared.
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.FieldReference],
+ _add_field_reference_to_dependencies,
+ skip_descendants_of={ir_data.Attribute},
+ incidental_actions={
+ ir_data.Field: _add_name_to_dependencies,
+ ir_data.EnumValue: _add_name_to_dependencies,
+ ir_data.RuntimeParameter: _add_name_to_dependencies,
+ },
+ parameters={"dependencies": dependencies},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Structure],
+ _find_dependency_ordering_for_fields_in_structure,
+ parameters={"dependencies": dependencies},
+ )
def _find_module_import_dependencies(ir):
- """Constructs a dependency graph of module imports."""
- dependencies = {}
- for module in ir.module:
- foreign_imports = set()
- for foreign_import in module.foreign_import:
- # The prelude gets an automatic self-import that shouldn't cause any
- # problems. No other self-imports are allowed, however.
- if foreign_import.file_name.text or module.source_file_name:
- foreign_imports |= {(foreign_import.file_name.text,)}
- dependencies[module.source_file_name,] = foreign_imports
- return dependencies
+ """Constructs a dependency graph of module imports."""
+ dependencies = {}
+ for module in ir.module:
+ foreign_imports = set()
+ for foreign_import in module.foreign_import:
+ # The prelude gets an automatic self-import that shouldn't cause any
+ # problems. No other self-imports are allowed, however.
+ if foreign_import.file_name.text or module.source_file_name:
+ foreign_imports |= {(foreign_import.file_name.text,)}
+ dependencies[module.source_file_name,] = foreign_imports
+ return dependencies
def _find_cycles(graph):
- """Finds cycles in graph.
+ """Finds cycles in graph.
- The graph does not need to be fully connected.
+ The graph does not need to be fully connected.
- Arguments:
- graph: A dictionary whose keys are node labels. Values are sets of node
- labels, representing edges from the key node to the value nodes.
+ Arguments:
+ graph: A dictionary whose keys are node labels. Values are sets of node
+ labels, representing edges from the key node to the value nodes.
- Returns:
- A set of sets of nodes which form strongly-connected components (subgraphs
- where every node is directly or indirectly reachable from every other node).
- No node will be included in more than one strongly-connected component, by
- definition. Strongly-connected components of size 1, where the node in the
- component does not have a self-edge, are not included in the result.
+ Returns:
+ A set of sets of nodes which form strongly-connected components (subgraphs
+ where every node is directly or indirectly reachable from every other node).
+ No node will be included in more than one strongly-connected component, by
+ definition. Strongly-connected components of size 1, where the node in the
+ component does not have a self-edge, are not included in the result.
- Note that a strongly-connected component may have a more complex structure
- than a single loop. For example:
+ Note that a strongly-connected component may have a more complex structure
+ than a single loop. For example:
- +-- A <-+ +-> B --+
- | | | |
- v C v
- D ^ ^ E
- | | | |
- +-> F --+ +-- G <-+
- """
- # This uses Tarjan's strongly-connected components algorithm, as described by
- # Wikipedia. This is a depth-first traversal of the graph with a node stack
- # that is independent of the call stack; nodes are added to the stack when
- # they are first encountered, but not removed until all nodes they can reach
- # have been checked.
- next_index = [0]
- node_indices = {}
- node_lowlinks = {}
- nodes_on_stack = set()
- stack = []
- nontrivial_components = set()
+ +-- A <-+ +-> B --+
+ | | | |
+ v C v
+ D ^ ^ E
+ | | | |
+ +-> F --+ +-- G <-+
+ """
+ # This uses Tarjan's strongly-connected components algorithm, as described by
+ # Wikipedia. This is a depth-first traversal of the graph with a node stack
+ # that is independent of the call stack; nodes are added to the stack when
+ # they are first encountered, but not removed until all nodes they can reach
+ # have been checked.
+ next_index = [0]
+ node_indices = {}
+ node_lowlinks = {}
+ nodes_on_stack = set()
+ stack = []
+ nontrivial_components = set()
- def strong_connect(node):
- """Implements the STRONGCONNECT routine of Tarjan's algorithm."""
- node_indices[node] = next_index[0]
- node_lowlinks[node] = next_index[0]
- next_index[0] += 1
- stack.append(node)
- nodes_on_stack.add(node)
+ def strong_connect(node):
+ """Implements the STRONGCONNECT routine of Tarjan's algorithm."""
+ node_indices[node] = next_index[0]
+ node_lowlinks[node] = next_index[0]
+ next_index[0] += 1
+ stack.append(node)
+ nodes_on_stack.add(node)
- for destination_node in graph[node]:
- if destination_node not in node_indices:
- strong_connect(destination_node)
- node_lowlinks[node] = min(node_lowlinks[node],
- node_lowlinks[destination_node])
- elif destination_node in nodes_on_stack:
- node_lowlinks[node] = min(node_lowlinks[node],
- node_indices[destination_node])
+ for destination_node in graph[node]:
+ if destination_node not in node_indices:
+ strong_connect(destination_node)
+ node_lowlinks[node] = min(
+ node_lowlinks[node], node_lowlinks[destination_node]
+ )
+ elif destination_node in nodes_on_stack:
+ node_lowlinks[node] = min(
+ node_lowlinks[node], node_indices[destination_node]
+ )
- strongly_connected_component = []
- if node_lowlinks[node] == node_indices[node]:
- while True:
- popped_node = stack.pop()
- nodes_on_stack.remove(popped_node)
- strongly_connected_component.append(popped_node)
- if popped_node == node:
- break
- if (len(strongly_connected_component) > 1 or
- strongly_connected_component[0] in
- graph[strongly_connected_component[0]]):
- nontrivial_components.add(frozenset(strongly_connected_component))
+ strongly_connected_component = []
+ if node_lowlinks[node] == node_indices[node]:
+ while True:
+ popped_node = stack.pop()
+ nodes_on_stack.remove(popped_node)
+ strongly_connected_component.append(popped_node)
+ if popped_node == node:
+ break
+ if (
+ len(strongly_connected_component) > 1
+ or strongly_connected_component[0]
+ in graph[strongly_connected_component[0]]
+ ):
+ nontrivial_components.add(frozenset(strongly_connected_component))
- for node in graph:
- if node not in node_indices:
- strong_connect(node)
- return nontrivial_components
+ for node in graph:
+ if node not in node_indices:
+ strong_connect(node)
+ return nontrivial_components
def _find_object_dependency_cycles(ir):
- """Finds dependency cycles in types in the ir."""
- dependencies, find_dependency_errors = _find_dependencies(ir)
- if find_dependency_errors:
- return find_dependency_errors
- errors = []
- cycles = _find_cycles(dict(dependencies))
- for cycle in cycles:
- # TODO(bolms): This lists the entire strongly-connected component in a
- # fairly arbitrary order. This is simple, and handles components that
- # aren't simple cycles, but may not be the most user-friendly way to
- # present this information.
- cycle_list = sorted(list(cycle))
- node_object = ir_util.find_object(cycle_list[0], ir)
- error_group = [
- error.error(cycle_list[0][0], node_object.source_location,
- "Dependency cycle\n" + node_object.name.name.text)
- ]
- for node in cycle_list[1:]:
- node_object = ir_util.find_object(node, ir)
- error_group.append(error.note(node[0], node_object.source_location,
- node_object.name.name.text))
- errors.append(error_group)
- return errors
+ """Finds dependency cycles in types in the ir."""
+ dependencies, find_dependency_errors = _find_dependencies(ir)
+ if find_dependency_errors:
+ return find_dependency_errors
+ errors = []
+ cycles = _find_cycles(dict(dependencies))
+ for cycle in cycles:
+ # TODO(bolms): This lists the entire strongly-connected component in a
+ # fairly arbitrary order. This is simple, and handles components that
+ # aren't simple cycles, but may not be the most user-friendly way to
+ # present this information.
+ cycle_list = sorted(list(cycle))
+ node_object = ir_util.find_object(cycle_list[0], ir)
+ error_group = [
+ error.error(
+ cycle_list[0][0],
+ node_object.source_location,
+ "Dependency cycle\n" + node_object.name.name.text,
+ )
+ ]
+ for node in cycle_list[1:]:
+ node_object = ir_util.find_object(node, ir)
+ error_group.append(
+ error.note(
+ node[0], node_object.source_location, node_object.name.name.text
+ )
+ )
+ errors.append(error_group)
+ return errors
def _find_module_dependency_cycles(ir):
- """Finds dependency cycles in modules in the ir."""
- dependencies = _find_module_import_dependencies(ir)
- cycles = _find_cycles(dict(dependencies))
- errors = []
- for cycle in cycles:
- cycle_list = sorted(list(cycle))
- module = ir_util.find_object(cycle_list[0], ir)
- error_group = [
- error.error(cycle_list[0][0], module.source_location,
- "Import dependency cycle\n" + module.source_file_name)
- ]
- for module_name in cycle_list[1:]:
- module = ir_util.find_object(module_name, ir)
- error_group.append(error.note(module_name[0], module.source_location,
- module.source_file_name))
- errors.append(error_group)
- return errors
+ """Finds dependency cycles in modules in the ir."""
+ dependencies = _find_module_import_dependencies(ir)
+ cycles = _find_cycles(dict(dependencies))
+ errors = []
+ for cycle in cycles:
+ cycle_list = sorted(list(cycle))
+ module = ir_util.find_object(cycle_list[0], ir)
+ error_group = [
+ error.error(
+ cycle_list[0][0],
+ module.source_location,
+ "Import dependency cycle\n" + module.source_file_name,
+ )
+ ]
+ for module_name in cycle_list[1:]:
+ module = ir_util.find_object(module_name, ir)
+ error_group.append(
+ error.note(
+ module_name[0], module.source_location, module.source_file_name
+ )
+ )
+ errors.append(error_group)
+ return errors
def find_dependency_cycles(ir):
- """Finds any dependency cycles in the ir."""
- errors = _find_module_dependency_cycles(ir)
- return errors + _find_object_dependency_cycles(ir)
+ """Finds any dependency cycles in the ir."""
+ errors = _find_module_dependency_cycles(ir)
+ return errors + _find_object_dependency_cycles(ir)
def set_dependency_order(ir):
- """Sets the fields_in_dependency_order member of Structures."""
- _find_dependency_ordering_for_fields(ir)
- return []
+ """Sets the fields_in_dependency_order member of Structures."""
+ _find_dependency_ordering_for_fields(ir)
+ return []
diff --git a/compiler/front_end/dependency_checker_test.py b/compiler/front_end/dependency_checker_test.py
index 27af812..a9110da 100644
--- a/compiler/front_end/dependency_checker_test.py
+++ b/compiler/front_end/dependency_checker_test.py
@@ -22,287 +22,408 @@
def _parse_snippet(emb_file):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": emb_file}),
- stop_before_step="find_dependency_cycles")
- assert not errors
- return ir
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_file}),
+ stop_before_step="find_dependency_cycles",
+ )
+ assert not errors
+ return ir
def _find_dependencies_for_snippet(emb_file):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({
- "m.emb": emb_file
- }),
- stop_before_step="set_dependency_order")
- assert not errors, errors
- return ir
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_file}),
+ stop_before_step="set_dependency_order",
+ )
+ assert not errors, errors
+ return ir
class DependencyCheckerTest(unittest.TestCase):
- def test_error_on_simple_field_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " 0 [+field2] UInt field1\n"
- " 0 [+field1] UInt field2\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfield1"),
- error.note("m.emb", struct.field[1].source_location, "field2")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_simple_field_cycle(self):
+ ir = _parse_snippet(
+ "struct Foo:\n"
+ " 0 [+field2] UInt field1\n"
+ " 0 [+field1] UInt field2\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfield1",
+ ),
+ error.note("m.emb", struct.field[1].source_location, "field2"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_self_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " 0 [+field1] UInt field1\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfield1")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_self_cycle(self):
+ ir = _parse_snippet("struct Foo:\n" " 0 [+field1] UInt field1\n")
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfield1",
+ )
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_triple_field_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " 0 [+field2] UInt field1\n"
- " 0 [+field3] UInt field2\n"
- " 0 [+field1] UInt field3\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfield1"),
- error.note("m.emb", struct.field[1].source_location, "field2"),
- error.note("m.emb", struct.field[2].source_location, "field3"),
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_triple_field_cycle(self):
+ ir = _parse_snippet(
+ "struct Foo:\n"
+ " 0 [+field2] UInt field1\n"
+ " 0 [+field3] UInt field2\n"
+ " 0 [+field1] UInt field3\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfield1",
+ ),
+ error.note("m.emb", struct.field[1].source_location, "field2"),
+ error.note("m.emb", struct.field[2].source_location, "field3"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_complex_field_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " 0 [+field2] UInt field1\n"
- " 0 [+field3+field4] UInt field2\n"
- " 0 [+field1] UInt field3\n"
- " 0 [+field2] UInt field4\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfield1"),
- error.note("m.emb", struct.field[1].source_location, "field2"),
- error.note("m.emb", struct.field[2].source_location, "field3"),
- error.note("m.emb", struct.field[3].source_location, "field4"),
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_complex_field_cycle(self):
+ ir = _parse_snippet(
+ "struct Foo:\n"
+ " 0 [+field2] UInt field1\n"
+ " 0 [+field3+field4] UInt field2\n"
+ " 0 [+field1] UInt field3\n"
+ " 0 [+field2] UInt field4\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfield1",
+ ),
+ error.note("m.emb", struct.field[1].source_location, "field2"),
+ error.note("m.emb", struct.field[2].source_location, "field3"),
+ error.note("m.emb", struct.field[3].source_location, "field4"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_simple_enum_value_cycle(self):
- ir = _parse_snippet("enum Foo:\n"
- " XX = YY\n"
- " YY = XX\n")
- enum = ir.module[0].type[0].enumeration
- self.assertEqual([[
- error.error("m.emb", enum.value[0].source_location,
- "Dependency cycle\nXX"),
- error.note("m.emb", enum.value[1].source_location, "YY")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_simple_enum_value_cycle(self):
+ ir = _parse_snippet("enum Foo:\n" " XX = YY\n" " YY = XX\n")
+ enum = ir.module[0].type[0].enumeration
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb", enum.value[0].source_location, "Dependency cycle\nXX"
+ ),
+ error.note("m.emb", enum.value[1].source_location, "YY"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_no_error_on_no_cycle(self):
- ir = _parse_snippet("enum Foo:\n"
- " XX = 0\n"
- " YY = XX\n")
- self.assertEqual([], dependency_checker.find_dependency_cycles(ir))
+ def test_no_error_on_no_cycle(self):
+ ir = _parse_snippet("enum Foo:\n" " XX = 0\n" " YY = XX\n")
+ self.assertEqual([], dependency_checker.find_dependency_cycles(ir))
- def test_error_on_cycle_nested(self):
- ir = _parse_snippet("struct Foo:\n"
- " struct Bar:\n"
- " 0 [+field2] UInt field1\n"
- " 0 [+field1] UInt field2\n"
- " 0 [+1] UInt field\n")
- struct = ir.module[0].type[0].subtype[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfield1"),
- error.note("m.emb", struct.field[1].source_location, "field2")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_cycle_nested(self):
+ ir = _parse_snippet(
+ "struct Foo:\n"
+ " struct Bar:\n"
+ " 0 [+field2] UInt field1\n"
+ " 0 [+field1] UInt field2\n"
+ " 0 [+1] UInt field\n"
+ )
+ struct = ir.module[0].type[0].subtype[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfield1",
+ ),
+ error.note("m.emb", struct.field[1].source_location, "field2"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_import_cycle(self):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": 'import "n.emb" as n\n',
- "n.emb": 'import "m.emb" as m\n'}),
- stop_before_step="find_dependency_cycles")
- assert not errors
- self.assertEqual([[
- error.error("m.emb", ir.module[0].source_location,
- "Import dependency cycle\nm.emb"),
- error.note("n.emb", ir.module[2].source_location, "n.emb")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_import_cycle(self):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader(
+ {"m.emb": 'import "n.emb" as n\n', "n.emb": 'import "m.emb" as m\n'}
+ ),
+ stop_before_step="find_dependency_cycles",
+ )
+ assert not errors
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ ir.module[0].source_location,
+ "Import dependency cycle\nm.emb",
+ ),
+ error.note("n.emb", ir.module[2].source_location, "n.emb"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_import_cycle_and_field_cycle(self):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": 'import "n.emb" as n\n'
- "struct Foo:\n"
- " 0 [+field1] UInt field1\n",
- "n.emb": 'import "m.emb" as m\n'}),
- stop_before_step="find_dependency_cycles")
- assert not errors
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", ir.module[0].source_location,
- "Import dependency cycle\nm.emb"),
- error.note("n.emb", ir.module[2].source_location, "n.emb")
- ], [
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfield1")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_import_cycle_and_field_cycle(self):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader(
+ {
+ "m.emb": 'import "n.emb" as n\n'
+ "struct Foo:\n"
+ " 0 [+field1] UInt field1\n",
+ "n.emb": 'import "m.emb" as m\n',
+ }
+ ),
+ stop_before_step="find_dependency_cycles",
+ )
+ assert not errors
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ ir.module[0].source_location,
+ "Import dependency cycle\nm.emb",
+ ),
+ error.note("n.emb", ir.module[2].source_location, "n.emb"),
+ ],
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfield1",
+ )
+ ],
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_field_existence_self_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " if x == 1:\n"
- " 0 [+1] UInt x\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nx")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_field_existence_self_cycle(self):
+ ir = _parse_snippet("struct Foo:\n" " if x == 1:\n" " 0 [+1] UInt x\n")
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb", struct.field[0].source_location, "Dependency cycle\nx"
+ )
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_field_existence_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " if y == 1:\n"
- " 0 [+1] UInt x\n"
- " if x == 0:\n"
- " 1 [+1] UInt y\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nx"),
- error.note("m.emb", struct.field[1].source_location, "y")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_field_existence_cycle(self):
+ ir = _parse_snippet(
+ "struct Foo:\n"
+ " if y == 1:\n"
+ " 0 [+1] UInt x\n"
+ " if x == 0:\n"
+ " 1 [+1] UInt y\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb", struct.field[0].source_location, "Dependency cycle\nx"
+ ),
+ error.note("m.emb", struct.field[1].source_location, "y"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_virtual_field_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " let x = y\n"
- " let y = x\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nx"),
- error.note("m.emb", struct.field[1].source_location, "y")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_virtual_field_cycle(self):
+ ir = _parse_snippet("struct Foo:\n" " let x = y\n" " let y = x\n")
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb", struct.field[0].source_location, "Dependency cycle\nx"
+ ),
+ error.note("m.emb", struct.field[1].source_location, "y"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_virtual_non_virtual_field_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " let x = y\n"
- " x [+4] UInt y\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nx"),
- error.note("m.emb", struct.field[1].source_location, "y")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_virtual_non_virtual_field_cycle(self):
+ ir = _parse_snippet("struct Foo:\n" " let x = y\n" " x [+4] UInt y\n")
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb", struct.field[0].source_location, "Dependency cycle\nx"
+ ),
+ error.note("m.emb", struct.field[1].source_location, "y"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_non_virtual_virtual_field_cycle(self):
- ir = _parse_snippet("struct Foo:\n"
- " y [+4] UInt x\n"
- " let y = x\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nx"),
- error.note("m.emb", struct.field[1].source_location, "y")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_non_virtual_virtual_field_cycle(self):
+ ir = _parse_snippet("struct Foo:\n" " y [+4] UInt x\n" " let y = x\n")
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb", struct.field[0].source_location, "Dependency cycle\nx"
+ ),
+ error.note("m.emb", struct.field[1].source_location, "y"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_error_on_cycle_involving_subfield(self):
- ir = _parse_snippet("struct Bar:\n"
- " foo_b.x [+4] Foo foo_a\n"
- " foo_a.x [+4] Foo foo_b\n"
- "struct Foo:\n"
- " 0 [+4] UInt x\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].source_location,
- "Dependency cycle\nfoo_a"),
- error.note("m.emb", struct.field[1].source_location, "foo_b")
- ]], dependency_checker.find_dependency_cycles(ir))
+ def test_error_on_cycle_involving_subfield(self):
+ ir = _parse_snippet(
+ "struct Bar:\n"
+ " foo_b.x [+4] Foo foo_a\n"
+ " foo_a.x [+4] Foo foo_b\n"
+ "struct Foo:\n"
+ " 0 [+4] UInt x\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].source_location,
+ "Dependency cycle\nfoo_a",
+ ),
+ error.note("m.emb", struct.field[1].source_location, "foo_b"),
+ ]
+ ],
+ dependency_checker.find_dependency_cycles(ir),
+ )
- def test_dependency_ordering_with_no_dependencies(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " 0 [+4] UInt a\n"
- " 4 [+4] UInt b\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([0, 1], struct.fields_in_dependency_order[:2])
+ def test_dependency_ordering_with_no_dependencies(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n" " 0 [+4] UInt a\n" " 4 [+4] UInt b\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([0, 1], struct.fields_in_dependency_order[:2])
- def test_dependency_ordering_with_dependency_in_order(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " 0 [+4] UInt a\n"
- " a [+4] UInt b\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([0, 1], struct.fields_in_dependency_order[:2])
+ def test_dependency_ordering_with_dependency_in_order(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n" " 0 [+4] UInt a\n" " a [+4] UInt b\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([0, 1], struct.fields_in_dependency_order[:2])
- def test_dependency_ordering_with_dependency_in_reverse_order(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " b [+4] UInt a\n"
- " 0 [+4] UInt b\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([1, 0], struct.fields_in_dependency_order[:2])
+ def test_dependency_ordering_with_dependency_in_reverse_order(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n" " b [+4] UInt a\n" " 0 [+4] UInt b\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([1, 0], struct.fields_in_dependency_order[:2])
- def test_dependency_ordering_with_extra_fields(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " d [+4] UInt a\n"
- " 4 [+4] UInt b\n"
- " 8 [+4] UInt c\n"
- " 12 [+4] UInt d\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([1, 2, 3, 0], struct.fields_in_dependency_order[:4])
+ def test_dependency_ordering_with_extra_fields(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n"
+ " d [+4] UInt a\n"
+ " 4 [+4] UInt b\n"
+ " 8 [+4] UInt c\n"
+ " 12 [+4] UInt d\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([1, 2, 3, 0], struct.fields_in_dependency_order[:4])
- def test_dependency_ordering_scrambled(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " d [+4] UInt a\n"
- " c [+4] UInt b\n"
- " a [+4] UInt c\n"
- " 12 [+4] UInt d\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([3, 0, 2, 1], struct.fields_in_dependency_order[:4])
+ def test_dependency_ordering_scrambled(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n"
+ " d [+4] UInt a\n"
+ " c [+4] UInt b\n"
+ " a [+4] UInt c\n"
+ " 12 [+4] UInt d\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([3, 0, 2, 1], struct.fields_in_dependency_order[:4])
- def test_dependency_ordering_multiple_dependents(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " d [+4] UInt a\n"
- " d [+4] UInt b\n"
- " d [+4] UInt c\n"
- " 12 [+4] UInt d\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([3, 0, 1, 2], struct.fields_in_dependency_order[:4])
+ def test_dependency_ordering_multiple_dependents(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n"
+ " d [+4] UInt a\n"
+ " d [+4] UInt b\n"
+ " d [+4] UInt c\n"
+ " 12 [+4] UInt d\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([3, 0, 1, 2], struct.fields_in_dependency_order[:4])
- def test_dependency_ordering_multiple_dependencies(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " b+c [+4] UInt a\n"
- " 4 [+4] UInt b\n"
- " 8 [+4] UInt c\n"
- " a [+4] UInt d\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([1, 2, 0, 3], struct.fields_in_dependency_order[:4])
+ def test_dependency_ordering_multiple_dependencies(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n"
+ " b+c [+4] UInt a\n"
+ " 4 [+4] UInt b\n"
+ " 8 [+4] UInt c\n"
+ " a [+4] UInt d\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([1, 2, 0, 3], struct.fields_in_dependency_order[:4])
- def test_dependency_ordering_with_parameter(self):
- ir = _find_dependencies_for_snippet("struct Foo:\n"
- " 0 [+1] Bar(x) b\n"
- " 1 [+1] UInt x\n"
- "struct Bar(x: UInt:8):\n"
- " x [+1] UInt y\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([1, 0], struct.fields_in_dependency_order[:2])
+ def test_dependency_ordering_with_parameter(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo:\n"
+ " 0 [+1] Bar(x) b\n"
+ " 1 [+1] UInt x\n"
+ "struct Bar(x: UInt:8):\n"
+ " x [+1] UInt y\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([1, 0], struct.fields_in_dependency_order[:2])
- def test_dependency_ordering_with_local_parameter(self):
- ir = _find_dependencies_for_snippet("struct Foo(x: Int:13):\n"
- " 0 [+x] Int b\n")
- self.assertEqual([], dependency_checker.set_dependency_order(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([0], struct.fields_in_dependency_order[:1])
+ def test_dependency_ordering_with_local_parameter(self):
+ ir = _find_dependencies_for_snippet(
+ "struct Foo(x: Int:13):\n" " 0 [+x] Int b\n"
+ )
+ self.assertEqual([], dependency_checker.set_dependency_order(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual([0], struct.fields_in_dependency_order[:1])
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/docs_are_up_to_date_test.py b/compiler/front_end/docs_are_up_to_date_test.py
index 70fd4b3..00ab32a 100644
--- a/compiler/front_end/docs_are_up_to_date_test.py
+++ b/compiler/front_end/docs_are_up_to_date_test.py
@@ -21,22 +21,22 @@
class DocsAreUpToDateTest(unittest.TestCase):
- """Tests that auto-generated, checked-in documentation is up to date."""
+ """Tests that auto-generated, checked-in documentation is up to date."""
- def test_grammar_md(self):
- doc_md = pkgutil.get_data("doc", "grammar.md").decode(encoding="UTF-8")
- correct_md = generate_grammar_md.generate_grammar_md()
- # If this fails, run:
- #
- # bazel run //compiler/front_end:generate_grammar_md > doc/grammar.md
- #
- # Be sure to check that the results look good before committing!
- doc_md_lines = doc_md.splitlines()
- correct_md_lines = correct_md.splitlines()
- for i in range(len(doc_md_lines)):
- self.assertEqual(correct_md_lines[i], doc_md_lines[i])
- self.assertEqual(correct_md, doc_md)
+ def test_grammar_md(self):
+ doc_md = pkgutil.get_data("doc", "grammar.md").decode(encoding="UTF-8")
+ correct_md = generate_grammar_md.generate_grammar_md()
+ # If this fails, run:
+ #
+ # bazel run //compiler/front_end:generate_grammar_md > doc/grammar.md
+ #
+ # Be sure to check that the results look good before committing!
+ doc_md_lines = doc_md.splitlines()
+ correct_md_lines = correct_md.splitlines()
+ for i in range(len(doc_md_lines)):
+ self.assertEqual(correct_md_lines[i], doc_md_lines[i])
+ self.assertEqual(correct_md, doc_md)
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/emboss_front_end.py b/compiler/front_end/emboss_front_end.py
index 1e30ded..c62638d 100644
--- a/compiler/front_end/emboss_front_end.py
+++ b/compiler/front_end/emboss_front_end.py
@@ -34,157 +34,180 @@
def _parse_command_line(argv):
- """Parses the given command-line arguments."""
- parser = argparse.ArgumentParser(description="Emboss compiler front end.",
- prog=argv[0])
- parser.add_argument("input_file",
- type=str,
- nargs=1,
- help=".emb file to compile.")
- parser.add_argument("--debug-show-tokenization",
- action="store_true",
- help="Show the tokenization of the main input file.")
- parser.add_argument("--debug-show-parse-tree",
- action="store_true",
- help="Show the parse tree of the main input file.")
- parser.add_argument("--debug-show-module-ir",
- action="store_true",
- help="Show the module-level IR of the main input file "
- "before symbol resolution.")
- parser.add_argument("--debug-show-full-ir",
- action="store_true",
- help="Show the final IR of the main input file.")
- parser.add_argument("--debug-show-used-productions",
- action="store_true",
- help="Show all of the grammar productions used in "
- "parsing the main input file.")
- parser.add_argument("--debug-show-unused-productions",
- action="store_true",
- help="Show all of the grammar productions not used in "
- "parsing the main input file.")
- parser.add_argument("--output-ir-to-stdout",
- action="store_true",
- help="Dump serialized IR to stdout.")
- parser.add_argument("--output-file",
- type=str,
- help="Write serialized IR to file.")
- parser.add_argument("--no-debug-show-header-lines",
- dest="debug_show_header_lines",
- action="store_false",
- help="Print header lines before output if true.")
- parser.add_argument("--color-output",
- default="if_tty",
- choices=["always", "never", "if_tty", "auto"],
- help="Print error messages using color. 'auto' is a "
- "synonym for 'if_tty'.")
- parser.add_argument("--import-dir", "-I",
- dest="import_dirs",
- action="append",
- default=["."],
- help="A directory to use when searching for imported "
- "embs. If no import_dirs are specified, the "
- "current directory will be used.")
- return parser.parse_args(argv[1:])
+ """Parses the given command-line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Emboss compiler front end.", prog=argv[0]
+ )
+ parser.add_argument("input_file", type=str, nargs=1, help=".emb file to compile.")
+ parser.add_argument(
+ "--debug-show-tokenization",
+ action="store_true",
+ help="Show the tokenization of the main input file.",
+ )
+ parser.add_argument(
+ "--debug-show-parse-tree",
+ action="store_true",
+ help="Show the parse tree of the main input file.",
+ )
+ parser.add_argument(
+ "--debug-show-module-ir",
+ action="store_true",
+ help="Show the module-level IR of the main input file "
+ "before symbol resolution.",
+ )
+ parser.add_argument(
+ "--debug-show-full-ir",
+ action="store_true",
+ help="Show the final IR of the main input file.",
+ )
+ parser.add_argument(
+ "--debug-show-used-productions",
+ action="store_true",
+ help="Show all of the grammar productions used in "
+ "parsing the main input file.",
+ )
+ parser.add_argument(
+ "--debug-show-unused-productions",
+ action="store_true",
+ help="Show all of the grammar productions not used in "
+ "parsing the main input file.",
+ )
+ parser.add_argument(
+ "--output-ir-to-stdout",
+ action="store_true",
+ help="Dump serialized IR to stdout.",
+ )
+ parser.add_argument("--output-file", type=str, help="Write serialized IR to file.")
+ parser.add_argument(
+ "--no-debug-show-header-lines",
+ dest="debug_show_header_lines",
+ action="store_false",
+ help="Print header lines before output if true.",
+ )
+ parser.add_argument(
+ "--color-output",
+ default="if_tty",
+ choices=["always", "never", "if_tty", "auto"],
+ help="Print error messages using color. 'auto' is a " "synonym for 'if_tty'.",
+ )
+ parser.add_argument(
+ "--import-dir",
+ "-I",
+ dest="import_dirs",
+ action="append",
+ default=["."],
+ help="A directory to use when searching for imported "
+ "embs. If no import_dirs are specified, the "
+ "current directory will be used.",
+ )
+ return parser.parse_args(argv[1:])
def _show_errors(errors, ir, color_output):
- """Prints errors with source code snippets."""
- source_codes = {}
- if ir:
- for module in ir.module:
- source_codes[module.source_file_name] = module.source_text
- use_color = (color_output == "always" or
- (color_output in ("auto", "if_tty") and
- os.isatty(sys.stderr.fileno())))
- print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
+ """Prints errors with source code snippets."""
+ source_codes = {}
+ if ir:
+ for module in ir.module:
+ source_codes[module.source_file_name] = module.source_text
+ use_color = color_output == "always" or (
+ color_output in ("auto", "if_tty") and os.isatty(sys.stderr.fileno())
+ )
+ print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
def _find_in_dirs_and_read(import_dirs):
- """Returns a function which will search import_dirs for a file."""
+ """Returns a function which will search import_dirs for a file."""
- def _find_and_read(file_name):
- """Searches import_dirs for file_name and returns the contents."""
- errors = []
- # *All* source files, including the one specified on the command line, will
- # be searched for in the import_dirs. This may be surprising, especially if
- # the current directory is *not* an import_dir.
- # TODO(bolms): Determine if this is really the desired behavior.
- for import_dir in import_dirs:
- full_name = path.join(import_dir, file_name)
- try:
- with open(full_name) as f:
- # As written, this follows the typical compiler convention of checking
- # the include/import directories in the order specified by the user,
- # and always reading the first matching file, even if other files
- # might match in later directories. This lets files shadow other
- # files, which can be useful in some cases (to override things), but
- # can also cause accidental shadowing, which can be tricky to fix.
- #
- # TODO(bolms): Check if any other files with the same name are in the
- # import path, and give a warning or error?
- return f.read(), None
- except IOError as e:
- errors.append(str(e))
- return None, errors + ["import path " + ":".join(import_dirs)]
+ def _find_and_read(file_name):
+ """Searches import_dirs for file_name and returns the contents."""
+ errors = []
+ # *All* source files, including the one specified on the command line, will
+ # be searched for in the import_dirs. This may be surprising, especially if
+ # the current directory is *not* an import_dir.
+ # TODO(bolms): Determine if this is really the desired behavior.
+ for import_dir in import_dirs:
+ full_name = path.join(import_dir, file_name)
+ try:
+ with open(full_name) as f:
+ # As written, this follows the typical compiler convention of checking
+ # the include/import directories in the order specified by the user,
+ # and always reading the first matching file, even if other files
+ # might match in later directories. This lets files shadow other
+ # files, which can be useful in some cases (to override things), but
+ # can also cause accidental shadowing, which can be tricky to fix.
+ #
+ # TODO(bolms): Check if any other files with the same name are in the
+ # import path, and give a warning or error?
+ return f.read(), None
+ except IOError as e:
+ errors.append(str(e))
+ return None, errors + ["import path " + ":".join(import_dirs)]
- return _find_and_read
+ return _find_and_read
+
def parse_and_log_errors(input_file, import_dirs, color_output):
- """Fully parses an .emb and logs any errors.
+ """Fully parses an .emb and logs any errors.
- Arguments:
- input_file: The path of the module source file.
- import_dirs: Directories to search for imported dependencies.
- color_output: Used when logging errors: "always", "never", "if_tty", "auto"
+ Arguments:
+ input_file: The path of the module source file.
+ import_dirs: Directories to search for imported dependencies.
+ color_output: Used when logging errors: "always", "never", "if_tty", "auto"
- Returns:
- (ir, debug_info, errors)
- """
- ir, debug_info, errors = glue.parse_emboss_file(
- input_file, _find_in_dirs_and_read(import_dirs))
- if errors:
- _show_errors(errors, ir, color_output)
+ Returns:
+ (ir, debug_info, errors)
+ """
+ ir, debug_info, errors = glue.parse_emboss_file(
+ input_file, _find_in_dirs_and_read(import_dirs)
+ )
+ if errors:
+ _show_errors(errors, ir, color_output)
- return (ir, debug_info, errors)
+ return (ir, debug_info, errors)
+
def main(flags):
- ir, debug_info, errors = parse_and_log_errors(
- flags.input_file[0], flags.import_dirs, flags.color_output)
- if errors:
- return 1
- main_module_debug_info = debug_info.modules[flags.input_file[0]]
- if flags.debug_show_tokenization:
- if flags.debug_show_header_lines:
- print("Tokenization:")
- print(main_module_debug_info.format_tokenization())
- if flags.debug_show_parse_tree:
- if flags.debug_show_header_lines:
- print("Parse Tree:")
- print(main_module_debug_info.format_parse_tree())
- if flags.debug_show_module_ir:
- if flags.debug_show_header_lines:
- print("Module IR:")
- print(main_module_debug_info.format_module_ir())
- if flags.debug_show_full_ir:
- if flags.debug_show_header_lines:
- print("Full IR:")
- print(str(ir))
- if flags.debug_show_used_productions:
- if flags.debug_show_header_lines:
- print("Used Productions:")
- print(glue.format_production_set(main_module_debug_info.used_productions))
- if flags.debug_show_unused_productions:
- if flags.debug_show_header_lines:
- print("Unused Productions:")
- print(glue.format_production_set(
- set(module_ir.PRODUCTIONS) - main_module_debug_info.used_productions))
- if flags.output_ir_to_stdout:
- print(ir_data_utils.IrDataSerializer(ir).to_json())
- if flags.output_file:
- with open(flags.output_file, "w") as f:
- f.write(ir_data_utils.IrDataSerializer(ir).to_json())
- return 0
+ ir, debug_info, errors = parse_and_log_errors(
+ flags.input_file[0], flags.import_dirs, flags.color_output
+ )
+ if errors:
+ return 1
+ main_module_debug_info = debug_info.modules[flags.input_file[0]]
+ if flags.debug_show_tokenization:
+ if flags.debug_show_header_lines:
+ print("Tokenization:")
+ print(main_module_debug_info.format_tokenization())
+ if flags.debug_show_parse_tree:
+ if flags.debug_show_header_lines:
+ print("Parse Tree:")
+ print(main_module_debug_info.format_parse_tree())
+ if flags.debug_show_module_ir:
+ if flags.debug_show_header_lines:
+ print("Module IR:")
+ print(main_module_debug_info.format_module_ir())
+ if flags.debug_show_full_ir:
+ if flags.debug_show_header_lines:
+ print("Full IR:")
+ print(str(ir))
+ if flags.debug_show_used_productions:
+ if flags.debug_show_header_lines:
+ print("Used Productions:")
+ print(glue.format_production_set(main_module_debug_info.used_productions))
+ if flags.debug_show_unused_productions:
+ if flags.debug_show_header_lines:
+ print("Unused Productions:")
+ print(
+ glue.format_production_set(
+ set(module_ir.PRODUCTIONS) - main_module_debug_info.used_productions
+ )
+ )
+ if flags.output_ir_to_stdout:
+ print(ir_data_utils.IrDataSerializer(ir).to_json())
+ if flags.output_file:
+ with open(flags.output_file, "w") as f:
+ f.write(ir_data_utils.IrDataSerializer(ir).to_json())
+ return 0
if __name__ == "__main__":
- sys.exit(main(_parse_command_line(sys.argv)))
+ sys.exit(main(_parse_command_line(sys.argv)))
diff --git a/compiler/front_end/expression_bounds.py b/compiler/front_end/expression_bounds.py
index e364399..cca36ee 100644
--- a/compiler/front_end/expression_bounds.py
+++ b/compiler/front_end/expression_bounds.py
@@ -26,702 +26,756 @@
# Create a local alias for math.gcd with a fallback to fractions.gcd if it is
# not available. This can be dropped if pre-3.5 Python support is dropped.
-if hasattr(math, 'gcd'):
- _math_gcd = math.gcd
+if hasattr(math, "gcd"):
+ _math_gcd = math.gcd
else:
- _math_gcd = fractions.gcd
+ _math_gcd = fractions.gcd
def compute_constraints_of_expression(expression, ir):
- """Adds appropriate bounding constraints to the given expression."""
- if ir_util.is_constant_type(expression.type):
- return
- expression_variety = expression.WhichOneof("expression")
- if expression_variety == "constant":
- _compute_constant_value_of_constant(expression)
- elif expression_variety == "constant_reference":
- _compute_constant_value_of_constant_reference(expression, ir)
- elif expression_variety == "function":
- _compute_constraints_of_function(expression, ir)
- elif expression_variety == "field_reference":
- _compute_constraints_of_field_reference(expression, ir)
- elif expression_variety == "builtin_reference":
- _compute_constraints_of_builtin_value(expression)
- elif expression_variety == "boolean_constant":
- _compute_constant_value_of_boolean_constant(expression)
- else:
- assert False, "Unknown expression variety {!r}".format(expression_variety)
- if expression.type.WhichOneof("type") == "integer":
- _assert_integer_constraints(expression)
+ """Adds appropriate bounding constraints to the given expression."""
+ if ir_util.is_constant_type(expression.type):
+ return
+ expression_variety = expression.WhichOneof("expression")
+ if expression_variety == "constant":
+ _compute_constant_value_of_constant(expression)
+ elif expression_variety == "constant_reference":
+ _compute_constant_value_of_constant_reference(expression, ir)
+ elif expression_variety == "function":
+ _compute_constraints_of_function(expression, ir)
+ elif expression_variety == "field_reference":
+ _compute_constraints_of_field_reference(expression, ir)
+ elif expression_variety == "builtin_reference":
+ _compute_constraints_of_builtin_value(expression)
+ elif expression_variety == "boolean_constant":
+ _compute_constant_value_of_boolean_constant(expression)
+ else:
+ assert False, "Unknown expression variety {!r}".format(expression_variety)
+ if expression.type.WhichOneof("type") == "integer":
+ _assert_integer_constraints(expression)
def _compute_constant_value_of_constant(expression):
- value = expression.constant.value
- expression.type.integer.modular_value = value
- expression.type.integer.minimum_value = value
- expression.type.integer.maximum_value = value
- expression.type.integer.modulus = "infinity"
+ value = expression.constant.value
+ expression.type.integer.modular_value = value
+ expression.type.integer.minimum_value = value
+ expression.type.integer.maximum_value = value
+ expression.type.integer.modulus = "infinity"
def _compute_constant_value_of_constant_reference(expression, ir):
- referred_object = ir_util.find_object(
- expression.constant_reference.canonical_name, ir)
- expression = ir_data_utils.builder(expression)
- if isinstance(referred_object, ir_data.EnumValue):
- compute_constraints_of_expression(referred_object.value, ir)
- assert ir_util.is_constant(referred_object.value)
- new_value = str(ir_util.constant_value(referred_object.value))
- expression.type.enumeration.value = new_value
- elif isinstance(referred_object, ir_data.Field):
- assert ir_util.field_is_virtual(referred_object), (
- "Non-virtual non-enum-value constant reference should have been caught "
- "in type_check.py")
- compute_constraints_of_expression(referred_object.read_transform, ir)
- expression.type.CopyFrom(referred_object.read_transform.type)
- else:
- assert False, "Unexpected constant reference type."
+ referred_object = ir_util.find_object(
+ expression.constant_reference.canonical_name, ir
+ )
+ expression = ir_data_utils.builder(expression)
+ if isinstance(referred_object, ir_data.EnumValue):
+ compute_constraints_of_expression(referred_object.value, ir)
+ assert ir_util.is_constant(referred_object.value)
+ new_value = str(ir_util.constant_value(referred_object.value))
+ expression.type.enumeration.value = new_value
+ elif isinstance(referred_object, ir_data.Field):
+ assert ir_util.field_is_virtual(referred_object), (
+ "Non-virtual non-enum-value constant reference should have been caught "
+ "in type_check.py"
+ )
+ compute_constraints_of_expression(referred_object.read_transform, ir)
+ expression.type.CopyFrom(referred_object.read_transform.type)
+ else:
+ assert False, "Unexpected constant reference type."
def _compute_constraints_of_function(expression, ir):
- """Computes the known constraints of the result of a function."""
- for arg in expression.function.args:
- compute_constraints_of_expression(arg, ir)
- op = expression.function.function
- if op in (ir_data.FunctionMapping.ADDITION, ir_data.FunctionMapping.SUBTRACTION):
- _compute_constraints_of_additive_operator(expression)
- elif op == ir_data.FunctionMapping.MULTIPLICATION:
- _compute_constraints_of_multiplicative_operator(expression)
- elif op in (ir_data.FunctionMapping.EQUALITY, ir_data.FunctionMapping.INEQUALITY,
- ir_data.FunctionMapping.LESS, ir_data.FunctionMapping.LESS_OR_EQUAL,
- ir_data.FunctionMapping.GREATER, ir_data.FunctionMapping.GREATER_OR_EQUAL,
- ir_data.FunctionMapping.AND, ir_data.FunctionMapping.OR):
- _compute_constant_value_of_comparison_operator(expression)
- elif op == ir_data.FunctionMapping.CHOICE:
- _compute_constraints_of_choice_operator(expression)
- elif op == ir_data.FunctionMapping.MAXIMUM:
- _compute_constraints_of_maximum_function(expression)
- elif op == ir_data.FunctionMapping.PRESENCE:
- _compute_constraints_of_existence_function(expression, ir)
- elif op in (ir_data.FunctionMapping.UPPER_BOUND, ir_data.FunctionMapping.LOWER_BOUND):
- _compute_constraints_of_bound_function(expression)
- else:
- assert False, "Unknown operator {!r}".format(op)
+ """Computes the known constraints of the result of a function."""
+ for arg in expression.function.args:
+ compute_constraints_of_expression(arg, ir)
+ op = expression.function.function
+ if op in (ir_data.FunctionMapping.ADDITION, ir_data.FunctionMapping.SUBTRACTION):
+ _compute_constraints_of_additive_operator(expression)
+ elif op == ir_data.FunctionMapping.MULTIPLICATION:
+ _compute_constraints_of_multiplicative_operator(expression)
+ elif op in (
+ ir_data.FunctionMapping.EQUALITY,
+ ir_data.FunctionMapping.INEQUALITY,
+ ir_data.FunctionMapping.LESS,
+ ir_data.FunctionMapping.LESS_OR_EQUAL,
+ ir_data.FunctionMapping.GREATER,
+ ir_data.FunctionMapping.GREATER_OR_EQUAL,
+ ir_data.FunctionMapping.AND,
+ ir_data.FunctionMapping.OR,
+ ):
+ _compute_constant_value_of_comparison_operator(expression)
+ elif op == ir_data.FunctionMapping.CHOICE:
+ _compute_constraints_of_choice_operator(expression)
+ elif op == ir_data.FunctionMapping.MAXIMUM:
+ _compute_constraints_of_maximum_function(expression)
+ elif op == ir_data.FunctionMapping.PRESENCE:
+ _compute_constraints_of_existence_function(expression, ir)
+ elif op in (
+ ir_data.FunctionMapping.UPPER_BOUND,
+ ir_data.FunctionMapping.LOWER_BOUND,
+ ):
+ _compute_constraints_of_bound_function(expression)
+ else:
+ assert False, "Unknown operator {!r}".format(op)
def _compute_constraints_of_existence_function(expression, ir):
- """Computes the constraints of a $has(field) expression."""
- field_path = expression.function.args[0].field_reference.path[-1]
- field = ir_util.find_object(field_path, ir)
- compute_constraints_of_expression(field.existence_condition, ir)
- ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type)
+ """Computes the constraints of a $has(field) expression."""
+ field_path = expression.function.args[0].field_reference.path[-1]
+ field = ir_util.find_object(field_path, ir)
+ compute_constraints_of_expression(field.existence_condition, ir)
+ ir_data_utils.builder(expression).type.CopyFrom(field.existence_condition.type)
def _compute_constraints_of_field_reference(expression, ir):
- """Computes the constraints of a reference to a structure's field."""
- field_path = expression.field_reference.path[-1]
- field = ir_util.find_object(field_path, ir)
- if isinstance(field, ir_data.Field) and ir_util.field_is_virtual(field):
- # References to virtual fields should have the virtual field's constraints
- # copied over.
- compute_constraints_of_expression(field.read_transform, ir)
- ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
- return
- # Non-virtual non-integer fields do not (yet) have constraints.
- if expression.type.WhichOneof("type") == "integer":
- # TODO(bolms): These lines will need to change when support is added for
- # fixed-point types.
- expression.type.integer.modulus = "1"
- expression.type.integer.modular_value = "0"
- type_definition = ir_util.find_parent_object(field_path, ir)
- if isinstance(field, ir_data.Field):
- referrent_type = field.type
- else:
- referrent_type = field.physical_type_alias
- if referrent_type.HasField("size_in_bits"):
- type_size = ir_util.constant_value(referrent_type.size_in_bits)
- else:
- field_size = ir_util.constant_value(field.location.size)
- if field_size is None:
- type_size = None
- else:
- type_size = field_size * type_definition.addressable_unit
- assert referrent_type.HasField("atomic_type"), field
- assert not referrent_type.atomic_type.reference.canonical_name.module_file
- _set_integer_constraints_from_physical_type(
- expression, referrent_type, type_size)
+ """Computes the constraints of a reference to a structure's field."""
+ field_path = expression.field_reference.path[-1]
+ field = ir_util.find_object(field_path, ir)
+ if isinstance(field, ir_data.Field) and ir_util.field_is_virtual(field):
+ # References to virtual fields should have the virtual field's constraints
+ # copied over.
+ compute_constraints_of_expression(field.read_transform, ir)
+ ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
+ return
+ # Non-virtual non-integer fields do not (yet) have constraints.
+ if expression.type.WhichOneof("type") == "integer":
+ # TODO(bolms): These lines will need to change when support is added for
+ # fixed-point types.
+ expression.type.integer.modulus = "1"
+ expression.type.integer.modular_value = "0"
+ type_definition = ir_util.find_parent_object(field_path, ir)
+ if isinstance(field, ir_data.Field):
+ referrent_type = field.type
+ else:
+ referrent_type = field.physical_type_alias
+ if referrent_type.HasField("size_in_bits"):
+ type_size = ir_util.constant_value(referrent_type.size_in_bits)
+ else:
+ field_size = ir_util.constant_value(field.location.size)
+ if field_size is None:
+ type_size = None
+ else:
+ type_size = field_size * type_definition.addressable_unit
+ assert referrent_type.HasField("atomic_type"), field
+ assert not referrent_type.atomic_type.reference.canonical_name.module_file
+ _set_integer_constraints_from_physical_type(
+ expression, referrent_type, type_size
+ )
-def _set_integer_constraints_from_physical_type(
- expression, physical_type, type_size):
- """Copies the integer constraints of an expression from a physical type."""
- # SCAFFOLDING HACK: In order to keep changelists manageable, this hardcodes
- # the ranges for all of the Emboss Prelude integer types. This would break
- # any user-defined `external` integer types, but that feature isn't fully
- # implemented in the C++ backend, so it doesn't matter for now.
- #
- # Adding the attribute(s) for integer bounds will require new operators:
- # integer/flooring division, remainder, and exponentiation (2**N, 10**N).
- #
- # (Technically, there are a few sets of operators that would work: for
- # example, just the choice operator `?:` is sufficient, but very ugly.
- # Bitwise AND, bitshift, and exponentiation would also work, but `10**($bits
- # >> 2) * 2**($bits & 0b11) - 1` isn't quite as clear as `10**($bits // 4) *
- # 2**($bits % 4) - 1`, in my (bolms@) opinion.)
- #
- # TODO(bolms): Add a scheme for defining integer bounds on user-defined
- # external types.
- if type_size is None:
- # If the type_size is unknown, then we can't actually say anything about the
- # minimum and maximum values of the type. For UInt, Int, and Bcd, an error
- # will be thrown during the constraints check stage.
- expression.type.integer.minimum_value = "-infinity"
- expression.type.integer.maximum_value = "infinity"
- return
- name = tuple(physical_type.atomic_type.reference.canonical_name.object_path)
- if name == ("UInt",):
- expression.type.integer.minimum_value = "0"
- expression.type.integer.maximum_value = str(2**type_size - 1)
- elif name == ("Int",):
- expression.type.integer.minimum_value = str(-(2**(type_size - 1)))
- expression.type.integer.maximum_value = str(2**(type_size - 1) - 1)
- elif name == ("Bcd",):
- expression.type.integer.minimum_value = "0"
- expression.type.integer.maximum_value = str(
- 10**(type_size // 4) * 2**(type_size % 4) - 1)
- else:
- assert False, "Unknown integral type " + ".".join(name)
+def _set_integer_constraints_from_physical_type(expression, physical_type, type_size):
+ """Copies the integer constraints of an expression from a physical type."""
+ # SCAFFOLDING HACK: In order to keep changelists manageable, this hardcodes
+ # the ranges for all of the Emboss Prelude integer types. This would break
+ # any user-defined `external` integer types, but that feature isn't fully
+ # implemented in the C++ backend, so it doesn't matter for now.
+ #
+ # Adding the attribute(s) for integer bounds will require new operators:
+ # integer/flooring division, remainder, and exponentiation (2**N, 10**N).
+ #
+ # (Technically, there are a few sets of operators that would work: for
+ # example, just the choice operator `?:` is sufficient, but very ugly.
+ # Bitwise AND, bitshift, and exponentiation would also work, but `10**($bits
+ # >> 2) * 2**($bits & 0b11) - 1` isn't quite as clear as `10**($bits // 4) *
+ # 2**($bits % 4) - 1`, in my (bolms@) opinion.)
+ #
+ # TODO(bolms): Add a scheme for defining integer bounds on user-defined
+ # external types.
+ if type_size is None:
+ # If the type_size is unknown, then we can't actually say anything about the
+ # minimum and maximum values of the type. For UInt, Int, and Bcd, an error
+ # will be thrown during the constraints check stage.
+ expression.type.integer.minimum_value = "-infinity"
+ expression.type.integer.maximum_value = "infinity"
+ return
+ name = tuple(physical_type.atomic_type.reference.canonical_name.object_path)
+ if name == ("UInt",):
+ expression.type.integer.minimum_value = "0"
+ expression.type.integer.maximum_value = str(2**type_size - 1)
+ elif name == ("Int",):
+ expression.type.integer.minimum_value = str(-(2 ** (type_size - 1)))
+ expression.type.integer.maximum_value = str(2 ** (type_size - 1) - 1)
+ elif name == ("Bcd",):
+ expression.type.integer.minimum_value = "0"
+ expression.type.integer.maximum_value = str(
+ 10 ** (type_size // 4) * 2 ** (type_size % 4) - 1
+ )
+ else:
+ assert False, "Unknown integral type " + ".".join(name)
def _compute_constraints_of_parameter(parameter):
- if parameter.type.WhichOneof("type") == "integer":
- type_size = ir_util.constant_value(
- parameter.physical_type_alias.size_in_bits)
- _set_integer_constraints_from_physical_type(
- parameter, parameter.physical_type_alias, type_size)
+ if parameter.type.WhichOneof("type") == "integer":
+ type_size = ir_util.constant_value(parameter.physical_type_alias.size_in_bits)
+ _set_integer_constraints_from_physical_type(
+ parameter, parameter.physical_type_alias, type_size
+ )
def _compute_constraints_of_builtin_value(expression):
- """Computes the constraints of a builtin (like $static_size_in_bits)."""
- name = expression.builtin_reference.canonical_name.object_path[0]
- if name == "$static_size_in_bits":
- expression.type.integer.modulus = "1"
- expression.type.integer.modular_value = "0"
- expression.type.integer.minimum_value = "0"
- # The maximum theoretically-supported size of something is 2**64 bytes,
- # which is 2**64 * 8 bits.
- #
- # Really, $static_size_in_bits is only valid in expressions that have to be
- # evaluated at compile time anyway, so it doesn't really matter if the
- # bounds are excessive.
- expression.type.integer.maximum_value = "infinity"
- elif name == "$is_statically_sized":
- # No bounds on a boolean variable.
- pass
- elif name == "$logical_value":
- # $logical_value is the placeholder used in inferred write-through
- # transformations.
- #
- # Only integers (currently) have "real" write-through transformations, but
- # fields that would otherwise be straight aliases, but which have a
- # [requires] attribute, are elevated to write-through fields, so that the
- # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite,
- # Read, and Ok.
- if expression.type.WhichOneof("type") == "integer":
- assert expression.type.integer.modulus
- assert expression.type.integer.modular_value
- assert expression.type.integer.minimum_value
- assert expression.type.integer.maximum_value
- elif expression.type.WhichOneof("type") == "enumeration":
- assert expression.type.enumeration.name
- elif expression.type.WhichOneof("type") == "boolean":
- pass
+ """Computes the constraints of a builtin (like $static_size_in_bits)."""
+ name = expression.builtin_reference.canonical_name.object_path[0]
+ if name == "$static_size_in_bits":
+ expression.type.integer.modulus = "1"
+ expression.type.integer.modular_value = "0"
+ expression.type.integer.minimum_value = "0"
+ # The maximum theoretically-supported size of something is 2**64 bytes,
+ # which is 2**64 * 8 bits.
+ #
+ # Really, $static_size_in_bits is only valid in expressions that have to be
+ # evaluated at compile time anyway, so it doesn't really matter if the
+ # bounds are excessive.
+ expression.type.integer.maximum_value = "infinity"
+ elif name == "$is_statically_sized":
+ # No bounds on a boolean variable.
+ pass
+ elif name == "$logical_value":
+ # $logical_value is the placeholder used in inferred write-through
+ # transformations.
+ #
+ # Only integers (currently) have "real" write-through transformations, but
+ # fields that would otherwise be straight aliases, but which have a
+ # [requires] attribute, are elevated to write-through fields, so that the
+ # [requires] clause can be checked in Write, CouldWriteValue, TryToWrite,
+ # Read, and Ok.
+ if expression.type.WhichOneof("type") == "integer":
+ assert expression.type.integer.modulus
+ assert expression.type.integer.modular_value
+ assert expression.type.integer.minimum_value
+ assert expression.type.integer.maximum_value
+ elif expression.type.WhichOneof("type") == "enumeration":
+ assert expression.type.enumeration.name
+ elif expression.type.WhichOneof("type") == "boolean":
+ pass
+ else:
+ assert False, "Unexpected type for $logical_value"
else:
- assert False, "Unexpected type for $logical_value"
- else:
- assert False, "Unknown builtin " + name
+ assert False, "Unknown builtin " + name
def _compute_constant_value_of_boolean_constant(expression):
- expression.type.boolean.value = expression.boolean_constant.value
+ expression.type.boolean.value = expression.boolean_constant.value
def _add(a, b):
- """Adds a and b, where a and b are ints, "infinity", or "-infinity"."""
- if a in ("infinity", "-infinity"):
- a, b = b, a
- if b == "infinity":
- assert a != "-infinity"
- return "infinity"
- if b == "-infinity":
- assert a != "infinity"
- return "-infinity"
- return int(a) + int(b)
+ """Adds a and b, where a and b are ints, "infinity", or "-infinity"."""
+ if a in ("infinity", "-infinity"):
+ a, b = b, a
+ if b == "infinity":
+ assert a != "-infinity"
+ return "infinity"
+ if b == "-infinity":
+ assert a != "infinity"
+ return "-infinity"
+ return int(a) + int(b)
def _sub(a, b):
- """Subtracts b from a, where a and b are ints, "infinity", or "-infinity"."""
- if b == "infinity":
- return _add(a, "-infinity")
- if b == "-infinity":
- return _add(a, "infinity")
- return _add(a, -int(b))
+ """Subtracts b from a, where a and b are ints, "infinity", or "-infinity"."""
+ if b == "infinity":
+ return _add(a, "-infinity")
+ if b == "-infinity":
+ return _add(a, "infinity")
+ return _add(a, -int(b))
def _sign(a):
- """Returns 1 if a > 0, 0 if a == 0, and -1 if a < 0."""
- if a == "infinity":
- return 1
- if a == "-infinity":
- return -1
- if int(a) > 0:
- return 1
- if int(a) < 0:
- return -1
- return 0
+ """Returns 1 if a > 0, 0 if a == 0, and -1 if a < 0."""
+ if a == "infinity":
+ return 1
+ if a == "-infinity":
+ return -1
+ if int(a) > 0:
+ return 1
+ if int(a) < 0:
+ return -1
+ return 0
def _mul(a, b):
- """Multiplies a and b, where a and b are ints, "infinity", or "-infinity"."""
- if _is_infinite(a):
- a, b = b, a
- if _is_infinite(b):
- sign = _sign(a) * _sign(b)
- if sign > 0:
- return "infinity"
- if sign < 0:
- return "-infinity"
- return 0
- return int(a) * int(b)
+ """Multiplies a and b, where a and b are ints, "infinity", or "-infinity"."""
+ if _is_infinite(a):
+ a, b = b, a
+ if _is_infinite(b):
+ sign = _sign(a) * _sign(b)
+ if sign > 0:
+ return "infinity"
+ if sign < 0:
+ return "-infinity"
+ return 0
+ return int(a) * int(b)
def _is_infinite(a):
- return a in ("infinity", "-infinity")
+ return a in ("infinity", "-infinity")
def _max(a):
- """Returns max of a, where elements are ints, "infinity", or "-infinity"."""
- if any(n == "infinity" for n in a):
- return "infinity"
- if all(n == "-infinity" for n in a):
- return "-infinity"
- return max(int(n) for n in a if not _is_infinite(n))
+ """Returns max of a, where elements are ints, "infinity", or "-infinity"."""
+ if any(n == "infinity" for n in a):
+ return "infinity"
+ if all(n == "-infinity" for n in a):
+ return "-infinity"
+ return max(int(n) for n in a if not _is_infinite(n))
def _min(a):
- """Returns min of a, where elements are ints, "infinity", or "-infinity"."""
- if any(n == "-infinity" for n in a):
- return "-infinity"
- if all(n == "infinity" for n in a):
- return "infinity"
- return min(int(n) for n in a if not _is_infinite(n))
+ """Returns min of a, where elements are ints, "infinity", or "-infinity"."""
+ if any(n == "-infinity" for n in a):
+ return "-infinity"
+ if all(n == "infinity" for n in a):
+ return "infinity"
+ return min(int(n) for n in a if not _is_infinite(n))
def _compute_constraints_of_additive_operator(expression):
- """Computes the modular value of an additive expression."""
- funcs = {
- ir_data.FunctionMapping.ADDITION: _add,
- ir_data.FunctionMapping.SUBTRACTION: _sub,
- }
- func = funcs[expression.function.function]
- args = expression.function.args
- for arg in args:
- assert arg.type.integer.modular_value, str(expression)
- left, right = args
- unadjusted_modular_value = func(left.type.integer.modular_value,
- right.type.integer.modular_value)
- new_modulus = _greatest_common_divisor(left.type.integer.modulus,
- right.type.integer.modulus)
- expression.type.integer.modulus = str(new_modulus)
- if new_modulus == "infinity":
- expression.type.integer.modular_value = str(unadjusted_modular_value)
- else:
- expression.type.integer.modular_value = str(unadjusted_modular_value %
- new_modulus)
- lmax = left.type.integer.maximum_value
- lmin = left.type.integer.minimum_value
- if expression.function.function == ir_data.FunctionMapping.SUBTRACTION:
- rmax = right.type.integer.minimum_value
- rmin = right.type.integer.maximum_value
- else:
- rmax = right.type.integer.maximum_value
- rmin = right.type.integer.minimum_value
- expression.type.integer.minimum_value = str(func(lmin, rmin))
- expression.type.integer.maximum_value = str(func(lmax, rmax))
+ """Computes the modular value of an additive expression."""
+ funcs = {
+ ir_data.FunctionMapping.ADDITION: _add,
+ ir_data.FunctionMapping.SUBTRACTION: _sub,
+ }
+ func = funcs[expression.function.function]
+ args = expression.function.args
+ for arg in args:
+ assert arg.type.integer.modular_value, str(expression)
+ left, right = args
+ unadjusted_modular_value = func(
+ left.type.integer.modular_value, right.type.integer.modular_value
+ )
+ new_modulus = _greatest_common_divisor(
+ left.type.integer.modulus, right.type.integer.modulus
+ )
+ expression.type.integer.modulus = str(new_modulus)
+ if new_modulus == "infinity":
+ expression.type.integer.modular_value = str(unadjusted_modular_value)
+ else:
+ expression.type.integer.modular_value = str(
+ unadjusted_modular_value % new_modulus
+ )
+ lmax = left.type.integer.maximum_value
+ lmin = left.type.integer.minimum_value
+ if expression.function.function == ir_data.FunctionMapping.SUBTRACTION:
+ rmax = right.type.integer.minimum_value
+ rmin = right.type.integer.maximum_value
+ else:
+ rmax = right.type.integer.maximum_value
+ rmin = right.type.integer.minimum_value
+ expression.type.integer.minimum_value = str(func(lmin, rmin))
+ expression.type.integer.maximum_value = str(func(lmax, rmax))
def _compute_constraints_of_multiplicative_operator(expression):
- """Computes the modular value of a multiplicative expression."""
- bounds = [arg.type.integer for arg in expression.function.args]
+ """Computes the modular value of a multiplicative expression."""
+ bounds = [arg.type.integer for arg in expression.function.args]
- # The minimum and maximum values can come from any of the four pairings of
- # (left min, left max) with (right min, right max), depending on the signs and
- # magnitudes of the minima and maxima. E.g.:
- #
- # max = left max * right max: [ 2, 3] * [ 2, 3]
- # max = left min * right min: [-3, -2] * [-3, -2]
- # max = left max * right min: [-3, -2] * [ 2, 3]
- # max = left min * right max: [ 2, 3] * [-3, -2]
- # max = left max * right max: [-2, 3] * [-2, 3]
- # max = left min * right min: [-3, 2] * [-3, 2]
- #
- # For uncorrelated multiplication, the minimum and maximum will always come
- # from multiplying one extreme by another: if x is nonzero, then
- #
- # (y + e) * x > y * x || (y - e) * x > y * x
- #
- # for arbitrary nonzero e, so the extrema can only occur when we either cannot
- # add or cannot subtract e.
- #
- # Correlated multiplication (e.g., `x * x`) can have tighter bounds, but
- # Emboss is not currently trying to be that smart.
- lmin, lmax = bounds[0].minimum_value, bounds[0].maximum_value
- rmin, rmax = bounds[1].minimum_value, bounds[1].maximum_value
- extrema = [_mul(lmax, rmax), _mul(lmin, rmax), #
- _mul(lmax, rmin), _mul(lmin, rmin)]
- expression.type.integer.minimum_value = str(_min(extrema))
- expression.type.integer.maximum_value = str(_max(extrema))
-
- if all(bound.modulus == "infinity" for bound in bounds):
- # If both sides are constant, the result is constant.
- expression.type.integer.modulus = "infinity"
- expression.type.integer.modular_value = str(int(bounds[0].modular_value) *
- int(bounds[1].modular_value))
- return
-
- if any(bound.modulus == "infinity" for bound in bounds):
- # If one side is constant and the other is not, then the non-constant
- # modulus and modular_value can both be multiplied by the constant. E.g.,
- # if `a` is congruent to 3 mod 5, then `4 * a` will be congruent to 12 mod
- # 20:
+ # The minimum and maximum values can come from any of the four pairings of
+ # (left min, left max) with (right min, right max), depending on the signs and
+ # magnitudes of the minima and maxima. E.g.:
#
- # a = ... | 4 * a = ... | 4 * a mod 20 = ...
- # 3 | 12 | 12
- # 8 | 32 | 12
- # 13 | 52 | 12
- # 18 | 72 | 12
- # 23 | 92 | 12
- # 28 | 112 | 12
- # 33 | 132 | 12
+ # max = left max * right max: [ 2, 3] * [ 2, 3]
+ # max = left min * right min: [-3, -2] * [-3, -2]
+ # max = left max * right min: [-3, -2] * [ 2, 3]
+ # max = left min * right max: [ 2, 3] * [-3, -2]
+ # max = left max * right max: [-2, 3] * [-2, 3]
+ # max = left min * right min: [-3, 2] * [-3, 2]
#
- # This is trivially shown by noting that the difference between consecutive
- # possible values for `4 * a` always differ by 20.
- if bounds[0].modulus == "infinity":
- constant, variable = bounds
- else:
- variable, constant = bounds
- if int(constant.modular_value) == 0:
- # If the constant is 0, the result is 0, no matter what the variable side
- # is.
- expression.type.integer.modulus = "infinity"
- expression.type.integer.modular_value = "0"
- return
- new_modulus = int(variable.modulus) * abs(int(constant.modular_value))
- expression.type.integer.modulus = str(new_modulus)
- # The `% new_modulus` will force the `modular_value` to be positive, even
- # when `constant.modular_value` is negative.
+ # For uncorrelated multiplication, the minimum and maximum will always come
+ # from multiplying one extreme by another: if x is nonzero, then
+ #
+ # (y + e) * x > y * x || (y - e) * x > y * x
+ #
+ # for arbitrary nonzero e, so the extrema can only occur when we either cannot
+ # add or cannot subtract e.
+ #
+ # Correlated multiplication (e.g., `x * x`) can have tighter bounds, but
+ # Emboss is not currently trying to be that smart.
+ lmin, lmax = bounds[0].minimum_value, bounds[0].maximum_value
+ rmin, rmax = bounds[1].minimum_value, bounds[1].maximum_value
+ extrema = [
+ _mul(lmax, rmax),
+ _mul(lmin, rmax), #
+ _mul(lmax, rmin),
+ _mul(lmin, rmin),
+ ]
+ expression.type.integer.minimum_value = str(_min(extrema))
+ expression.type.integer.maximum_value = str(_max(extrema))
+
+ if all(bound.modulus == "infinity" for bound in bounds):
+ # If both sides are constant, the result is constant.
+ expression.type.integer.modulus = "infinity"
+ expression.type.integer.modular_value = str(
+ int(bounds[0].modular_value) * int(bounds[1].modular_value)
+ )
+ return
+
+ if any(bound.modulus == "infinity" for bound in bounds):
+ # If one side is constant and the other is not, then the non-constant
+ # modulus and modular_value can both be multiplied by the constant. E.g.,
+ # if `a` is congruent to 3 mod 5, then `4 * a` will be congruent to 12 mod
+ # 20:
+ #
+ # a = ... | 4 * a = ... | 4 * a mod 20 = ...
+ # 3 | 12 | 12
+ # 8 | 32 | 12
+ # 13 | 52 | 12
+ # 18 | 72 | 12
+ # 23 | 92 | 12
+ # 28 | 112 | 12
+ # 33 | 132 | 12
+ #
+ # This is trivially shown by noting that the difference between consecutive
+ # possible values for `4 * a` always differ by 20.
+ if bounds[0].modulus == "infinity":
+ constant, variable = bounds
+ else:
+ variable, constant = bounds
+ if int(constant.modular_value) == 0:
+ # If the constant is 0, the result is 0, no matter what the variable side
+ # is.
+ expression.type.integer.modulus = "infinity"
+ expression.type.integer.modular_value = "0"
+ return
+ new_modulus = int(variable.modulus) * abs(int(constant.modular_value))
+ expression.type.integer.modulus = str(new_modulus)
+ # The `% new_modulus` will force the `modular_value` to be positive, even
+ # when `constant.modular_value` is negative.
+ expression.type.integer.modular_value = str(
+ int(variable.modular_value) * int(constant.modular_value) % new_modulus
+ )
+ return
+
+ # If neither side is constant, then the result is more complex. Full proof is
+ # available in g3doc/modular_congruence_multiplication_proof.md
+ #
+ # Essentially, if:
+ #
+ # l == _ * l_mod + l_mv
+ # r == _ * r_mod + r_mv
+ #
+ # Then we find l_mod0 and r_mod0 in:
+ #
+ # l == (_ * l_mod_nz + l_mv_nz) * l_mod0
+ # r == (_ * r_mod_nz + r_mv_nz) * r_mod0
+ #
+ # And finally conclude:
+ #
+ # l * r == _ * GCD(l_mod_nz, r_mod_nz) * l_mod0 * r_mod0 + l_mv * r_mv
+ product_of_zero_congruence_moduli = 1
+ product_of_modular_values = 1
+ nonzero_congruence_moduli = []
+ for bound in bounds:
+ zero_congruence_modulus = _greatest_common_divisor(
+ bound.modulus, bound.modular_value
+ )
+ assert int(bound.modulus) % zero_congruence_modulus == 0
+ product_of_zero_congruence_moduli *= zero_congruence_modulus
+ product_of_modular_values *= int(bound.modular_value)
+ nonzero_congruence_moduli.append(int(bound.modulus) // zero_congruence_modulus)
+ shared_nonzero_congruence_modulus = _greatest_common_divisor(
+ nonzero_congruence_moduli[0], nonzero_congruence_moduli[1]
+ )
+ final_modulus = (
+ shared_nonzero_congruence_modulus * product_of_zero_congruence_moduli
+ )
+ expression.type.integer.modulus = str(final_modulus)
expression.type.integer.modular_value = str(
- int(variable.modular_value) * int(constant.modular_value) % new_modulus)
- return
-
- # If neither side is constant, then the result is more complex. Full proof is
- # available in g3doc/modular_congruence_multiplication_proof.md
- #
- # Essentially, if:
- #
- # l == _ * l_mod + l_mv
- # r == _ * r_mod + r_mv
- #
- # Then we find l_mod0 and r_mod0 in:
- #
- # l == (_ * l_mod_nz + l_mv_nz) * l_mod0
- # r == (_ * r_mod_nz + r_mv_nz) * r_mod0
- #
- # And finally conclude:
- #
- # l * r == _ * GCD(l_mod_nz, r_mod_nz) * l_mod0 * r_mod0 + l_mv * r_mv
- product_of_zero_congruence_moduli = 1
- product_of_modular_values = 1
- nonzero_congruence_moduli = []
- for bound in bounds:
- zero_congruence_modulus = _greatest_common_divisor(bound.modulus,
- bound.modular_value)
- assert int(bound.modulus) % zero_congruence_modulus == 0
- product_of_zero_congruence_moduli *= zero_congruence_modulus
- product_of_modular_values *= int(bound.modular_value)
- nonzero_congruence_moduli.append(int(bound.modulus) //
- zero_congruence_modulus)
- shared_nonzero_congruence_modulus = _greatest_common_divisor(
- nonzero_congruence_moduli[0], nonzero_congruence_moduli[1])
- final_modulus = (shared_nonzero_congruence_modulus *
- product_of_zero_congruence_moduli)
- expression.type.integer.modulus = str(final_modulus)
- expression.type.integer.modular_value = str(product_of_modular_values %
- final_modulus)
+ product_of_modular_values % final_modulus
+ )
def _assert_integer_constraints(expression):
- """Asserts that the integer bounds of expression are self-consistent.
+ """Asserts that the integer bounds of expression are self-consistent.
- Asserts that `minimum_value` and `maximum_value` are congruent to
- `modular_value` modulo `modulus`.
+ Asserts that `minimum_value` and `maximum_value` are congruent to
+ `modular_value` modulo `modulus`.
- If `modulus` is "infinity", asserts that `minimum_value`, `maximum_value`, and
- `modular_value` are all equal.
+ If `modulus` is "infinity", asserts that `minimum_value`, `maximum_value`, and
+ `modular_value` are all equal.
- If `minimum_value` is equal to `maximum_value`, asserts that `modular_value`
- is equal to both, and that `modulus` is "infinity".
+ If `minimum_value` is equal to `maximum_value`, asserts that `modular_value`
+ is equal to both, and that `modulus` is "infinity".
- Arguments:
- expression: an expression with type.integer
+ Arguments:
+ expression: an expression with type.integer
- Returns:
- None
- """
- bounds = expression.type.integer
- if bounds.modulus == "infinity":
- assert bounds.minimum_value == bounds.modular_value
- assert bounds.maximum_value == bounds.modular_value
- return
- modulus = int(bounds.modulus)
- assert modulus > 0
- if bounds.minimum_value != "-infinity":
- assert int(bounds.minimum_value) % modulus == int(bounds.modular_value)
- if bounds.maximum_value != "infinity":
- assert int(bounds.maximum_value) % modulus == int(bounds.modular_value)
- if bounds.minimum_value == bounds.maximum_value:
- # TODO(bolms): I believe there are situations using the not-yet-implemented
- # integer division operator that would trigger these asserts, so they should
- # be turned into assignments (with corresponding tests) when implementing
- # division.
- assert bounds.modular_value == bounds.minimum_value
- assert bounds.modulus == "infinity"
- if bounds.minimum_value != "-infinity" and bounds.maximum_value != "infinity":
- assert int(bounds.minimum_value) <= int(bounds.maximum_value)
+ Returns:
+ None
+ """
+ bounds = expression.type.integer
+ if bounds.modulus == "infinity":
+ assert bounds.minimum_value == bounds.modular_value
+ assert bounds.maximum_value == bounds.modular_value
+ return
+ modulus = int(bounds.modulus)
+ assert modulus > 0
+ if bounds.minimum_value != "-infinity":
+ assert int(bounds.minimum_value) % modulus == int(bounds.modular_value)
+ if bounds.maximum_value != "infinity":
+ assert int(bounds.maximum_value) % modulus == int(bounds.modular_value)
+ if bounds.minimum_value == bounds.maximum_value:
+ # TODO(bolms): I believe there are situations using the not-yet-implemented
+ # integer division operator that would trigger these asserts, so they should
+ # be turned into assignments (with corresponding tests) when implementing
+ # division.
+ assert bounds.modular_value == bounds.minimum_value
+ assert bounds.modulus == "infinity"
+ if bounds.minimum_value != "-infinity" and bounds.maximum_value != "infinity":
+ assert int(bounds.minimum_value) <= int(bounds.maximum_value)
def _compute_constant_value_of_comparison_operator(expression):
- """Computes the constant value, if any, of a comparison operator."""
- args = expression.function.args
- if all(ir_util.is_constant(arg) for arg in args):
- functions = {
- ir_data.FunctionMapping.EQUALITY: operator.eq,
- ir_data.FunctionMapping.INEQUALITY: operator.ne,
- ir_data.FunctionMapping.LESS: operator.lt,
- ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le,
- ir_data.FunctionMapping.GREATER: operator.gt,
- ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge,
- ir_data.FunctionMapping.AND: operator.and_,
- ir_data.FunctionMapping.OR: operator.or_,
- }
- func = functions[expression.function.function]
- expression.type.boolean.value = func(
- *[ir_util.constant_value(arg) for arg in args])
+ """Computes the constant value, if any, of a comparison operator."""
+ args = expression.function.args
+ if all(ir_util.is_constant(arg) for arg in args):
+ functions = {
+ ir_data.FunctionMapping.EQUALITY: operator.eq,
+ ir_data.FunctionMapping.INEQUALITY: operator.ne,
+ ir_data.FunctionMapping.LESS: operator.lt,
+ ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le,
+ ir_data.FunctionMapping.GREATER: operator.gt,
+ ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge,
+ ir_data.FunctionMapping.AND: operator.and_,
+ ir_data.FunctionMapping.OR: operator.or_,
+ }
+ func = functions[expression.function.function]
+ expression.type.boolean.value = func(
+ *[ir_util.constant_value(arg) for arg in args]
+ )
def _compute_constraints_of_bound_function(expression):
- """Computes the constraints of $upper_bound or $lower_bound."""
- if expression.function.function == ir_data.FunctionMapping.UPPER_BOUND:
- value = expression.function.args[0].type.integer.maximum_value
- elif expression.function.function == ir_data.FunctionMapping.LOWER_BOUND:
- value = expression.function.args[0].type.integer.minimum_value
- else:
- assert False, "Non-bound function"
- expression.type.integer.minimum_value = value
- expression.type.integer.maximum_value = value
- expression.type.integer.modular_value = value
- expression.type.integer.modulus = "infinity"
+ """Computes the constraints of $upper_bound or $lower_bound."""
+ if expression.function.function == ir_data.FunctionMapping.UPPER_BOUND:
+ value = expression.function.args[0].type.integer.maximum_value
+ elif expression.function.function == ir_data.FunctionMapping.LOWER_BOUND:
+ value = expression.function.args[0].type.integer.minimum_value
+ else:
+ assert False, "Non-bound function"
+ expression.type.integer.minimum_value = value
+ expression.type.integer.maximum_value = value
+ expression.type.integer.modular_value = value
+ expression.type.integer.modulus = "infinity"
def _compute_constraints_of_maximum_function(expression):
- """Computes the constraints of the $max function."""
- assert expression.type.WhichOneof("type") == "integer"
- args = expression.function.args
- assert args[0].type.WhichOneof("type") == "integer"
- # The minimum value of the result occurs when every argument takes its minimum
- # value, which means that the minimum result is the maximum-of-minimums.
- expression.type.integer.minimum_value = str(_max(
- [arg.type.integer.minimum_value for arg in args]))
- # The maximum result is the maximum-of-maximums.
- expression.type.integer.maximum_value = str(_max(
- [arg.type.integer.maximum_value for arg in args]))
- # If the expression is dominated by a constant factor, then the result is
- # constant. I (bolms@) believe this is the only case where
- # _compute_constraints_of_maximum_function might violate the assertions in
- # _assert_integer_constraints.
- if (expression.type.integer.minimum_value ==
- expression.type.integer.maximum_value):
- expression.type.integer.modular_value = (
- expression.type.integer.minimum_value)
- expression.type.integer.modulus = "infinity"
- return
- result_modulus = args[0].type.integer.modulus
- result_modular_value = args[0].type.integer.modular_value
- # The result of $max(a, b) could be either a or b, which means that the result
- # of $max(a, b) uses the _shared_modular_value() of a and b, just like the
- # choice operator '?:'.
- #
- # This also takes advantage of the fact that $max(a, b, c, d, ...) is
- # equivalent to $max(a, $max(b, $max(c, $max(d, ...)))), so it is valid to
- # call _shared_modular_value() in a loop.
- for arg in args[1:]:
- # TODO(bolms): I think the bounds could be tigher in some cases where
- # arg.maximum_value is less than the new expression.minimum_value, and
- # in some very specific cases where arg.maximum_value is greater than the
- # new expression.minimum_value, but arg.maximum_value - arg.modulus is less
- # than expression.minimum_value.
- result_modulus, result_modular_value = _shared_modular_value(
- (result_modulus, result_modular_value),
- (arg.type.integer.modulus, arg.type.integer.modular_value))
- expression.type.integer.modulus = str(result_modulus)
- expression.type.integer.modular_value = str(result_modular_value)
+ """Computes the constraints of the $max function."""
+ assert expression.type.WhichOneof("type") == "integer"
+ args = expression.function.args
+ assert args[0].type.WhichOneof("type") == "integer"
+ # The minimum value of the result occurs when every argument takes its minimum
+ # value, which means that the minimum result is the maximum-of-minimums.
+ expression.type.integer.minimum_value = str(
+ _max([arg.type.integer.minimum_value for arg in args])
+ )
+ # The maximum result is the maximum-of-maximums.
+ expression.type.integer.maximum_value = str(
+ _max([arg.type.integer.maximum_value for arg in args])
+ )
+ # If the expression is dominated by a constant factor, then the result is
+ # constant. I (bolms@) believe this is the only case where
+ # _compute_constraints_of_maximum_function might violate the assertions in
+ # _assert_integer_constraints.
+ if expression.type.integer.minimum_value == expression.type.integer.maximum_value:
+ expression.type.integer.modular_value = expression.type.integer.minimum_value
+ expression.type.integer.modulus = "infinity"
+ return
+ result_modulus = args[0].type.integer.modulus
+ result_modular_value = args[0].type.integer.modular_value
+ # The result of $max(a, b) could be either a or b, which means that the result
+ # of $max(a, b) uses the _shared_modular_value() of a and b, just like the
+ # choice operator '?:'.
+ #
+ # This also takes advantage of the fact that $max(a, b, c, d, ...) is
+ # equivalent to $max(a, $max(b, $max(c, $max(d, ...)))), so it is valid to
+ # call _shared_modular_value() in a loop.
+ for arg in args[1:]:
+        # TODO(bolms): I think the bounds could be tighter in some cases where
+ # arg.maximum_value is less than the new expression.minimum_value, and
+ # in some very specific cases where arg.maximum_value is greater than the
+ # new expression.minimum_value, but arg.maximum_value - arg.modulus is less
+ # than expression.minimum_value.
+ result_modulus, result_modular_value = _shared_modular_value(
+ (result_modulus, result_modular_value),
+ (arg.type.integer.modulus, arg.type.integer.modular_value),
+ )
+ expression.type.integer.modulus = str(result_modulus)
+ expression.type.integer.modular_value = str(result_modular_value)
def _shared_modular_value(left, right):
- """Returns the shared modulus and modular value of left and right.
+ """Returns the shared modulus and modular value of left and right.
- Arguments:
- left: A tuple of (modulus, modular value)
- right: A tuple of (modulus, modular value)
+ Arguments:
+ left: A tuple of (modulus, modular value)
+ right: A tuple of (modulus, modular value)
- Returns:
- A tuple of (modulus, modular_value) such that:
+ Returns:
+ A tuple of (modulus, modular_value) such that:
- left.modulus % result.modulus == 0
- right.modulus % result.modulus == 0
- left.modular_value % result.modulus = result.modular_value
- right.modular_value % result.modulus = result.modular_value
+ left.modulus % result.modulus == 0
+ right.modulus % result.modulus == 0
+ left.modular_value % result.modulus = result.modular_value
+ right.modular_value % result.modulus = result.modular_value
- That is, the result.modulus and result.modular_value will be compatible
- with, but (possibly) less restrictive than both left.(modulus,
- modular_value) and right.(modulus, modular_value).
- """
- left_modulus, left_modular_value = left
- right_modulus, right_modular_value = right
- # The combined modulus is gcd(gcd(left_modulus, right_modulus),
- # left_modular_value - right_modular_value).
- #
- # The inner gcd normalizes the left_modulus and right_modulus, but can leave
- # incompatible modular_values. The outer gcd finds a modulus to which both
- # modular_values are congruent. Some examples:
- #
- # left | right | res
- # --------------+----------------+--------------------
- # l % 12 == 7 | r % 20 == 15 | res % 4 == 3
- # l == 35 | r % 20 == 15 | res % 20 == 15
- # l % 24 == 15 | r % 12 == 7 | res % 4 == 3
- # l % 20 == 15 | r % 20 == 10 | res % 5 == 0
- # l % 20 == 16 | r % 20 == 11 | res % 5 == 1
- # l == 10 | r == 7 | res % 3 == 1
- # l == 4 | r == 4 | res == 4
- #
- # The cases where one side or the other are constant are handled
- # automatically by the fact that _greatest_common_divisor("infinity", x)
- # is x.
- common_modulus = _greatest_common_divisor(left_modulus, right_modulus)
- new_modulus = _greatest_common_divisor(
- common_modulus, abs(int(left_modular_value) - int(right_modular_value)))
- if new_modulus == "infinity":
- # The only way for the new_modulus to come out as "infinity" *should* be
- # if both if_true and if_false have the same constant value.
- assert left_modular_value == right_modular_value
- assert left_modulus == right_modulus == "infinity"
- return new_modulus, left_modular_value
- else:
- assert (int(left_modular_value) % new_modulus ==
- int(right_modular_value) % new_modulus)
- return new_modulus, int(left_modular_value) % new_modulus
+ That is, the result.modulus and result.modular_value will be compatible
+ with, but (possibly) less restrictive than both left.(modulus,
+ modular_value) and right.(modulus, modular_value).
+ """
+ left_modulus, left_modular_value = left
+ right_modulus, right_modular_value = right
+ # The combined modulus is gcd(gcd(left_modulus, right_modulus),
+ # left_modular_value - right_modular_value).
+ #
+ # The inner gcd normalizes the left_modulus and right_modulus, but can leave
+ # incompatible modular_values. The outer gcd finds a modulus to which both
+ # modular_values are congruent. Some examples:
+ #
+ # left | right | res
+ # --------------+----------------+--------------------
+ # l % 12 == 7 | r % 20 == 15 | res % 4 == 3
+ # l == 35 | r % 20 == 15 | res % 20 == 15
+ # l % 24 == 15 | r % 12 == 7 | res % 4 == 3
+ # l % 20 == 15 | r % 20 == 10 | res % 5 == 0
+ # l % 20 == 16 | r % 20 == 11 | res % 5 == 1
+ # l == 10 | r == 7 | res % 3 == 1
+ # l == 4 | r == 4 | res == 4
+ #
+ # The cases where one side or the other are constant are handled
+ # automatically by the fact that _greatest_common_divisor("infinity", x)
+ # is x.
+ common_modulus = _greatest_common_divisor(left_modulus, right_modulus)
+ new_modulus = _greatest_common_divisor(
+ common_modulus, abs(int(left_modular_value) - int(right_modular_value))
+ )
+ if new_modulus == "infinity":
+ # The only way for the new_modulus to come out as "infinity" *should* be
+ # if both if_true and if_false have the same constant value.
+ assert left_modular_value == right_modular_value
+ assert left_modulus == right_modulus == "infinity"
+ return new_modulus, left_modular_value
+ else:
+ assert (
+ int(left_modular_value) % new_modulus
+ == int(right_modular_value) % new_modulus
+ )
+ return new_modulus, int(left_modular_value) % new_modulus
def _compute_constraints_of_choice_operator(expression):
- """Computes the constraints of a choice operation '?:'."""
- condition, if_true, if_false = ir_data_utils.reader(expression).function.args
- expression = ir_data_utils.builder(expression)
- if condition.type.boolean.HasField("value"):
- # The generated expressions for $size_in_bits and $size_in_bytes look like
- #
- # $max((field1_existence_condition ? field1_start + field1_size : 0),
- # (field2_existence_condition ? field2_start + field2_size : 0),
- # (field3_existence_condition ? field3_start + field3_size : 0),
- # ...)
- #
- # Since most existence_conditions are just "true", it is important to select
- # the tighter bounds in those cases -- otherwise, only zero-length
- # structures could have a constant $size_in_bits or $size_in_bytes.
- side = if_true if condition.type.boolean.value else if_false
- expression.type.CopyFrom(side.type)
- return
- # The type.integer minimum_value/maximum_value bounding code is needed since
- # constraints.check_constraints() will complain if minimum and maximum are not
- # set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its
- # weight, but for completeness I've left it in.
- if if_true.type.WhichOneof("type") == "integer":
- # The minimum value of the choice is the minimum value of either side, and
- # the maximum is the maximum value of either side.
- expression.type.integer.minimum_value = str(_min([
- if_true.type.integer.minimum_value,
- if_false.type.integer.minimum_value]))
- expression.type.integer.maximum_value = str(_max([
- if_true.type.integer.maximum_value,
- if_false.type.integer.maximum_value]))
- new_modulus, new_modular_value = _shared_modular_value(
- (if_true.type.integer.modulus, if_true.type.integer.modular_value),
- (if_false.type.integer.modulus, if_false.type.integer.modular_value))
- expression.type.integer.modulus = str(new_modulus)
- expression.type.integer.modular_value = str(new_modular_value)
- else:
- assert if_true.type.WhichOneof("type") in ("boolean", "enumeration"), (
- "Unknown type {} for expression".format(
- if_true.type.WhichOneof("type")))
+ """Computes the constraints of a choice operation '?:'."""
+ condition, if_true, if_false = ir_data_utils.reader(expression).function.args
+ expression = ir_data_utils.builder(expression)
+ if condition.type.boolean.HasField("value"):
+ # The generated expressions for $size_in_bits and $size_in_bytes look like
+ #
+ # $max((field1_existence_condition ? field1_start + field1_size : 0),
+ # (field2_existence_condition ? field2_start + field2_size : 0),
+ # (field3_existence_condition ? field3_start + field3_size : 0),
+ # ...)
+ #
+ # Since most existence_conditions are just "true", it is important to select
+ # the tighter bounds in those cases -- otherwise, only zero-length
+ # structures could have a constant $size_in_bits or $size_in_bytes.
+ side = if_true if condition.type.boolean.value else if_false
+ expression.type.CopyFrom(side.type)
+ return
+ # The type.integer minimum_value/maximum_value bounding code is needed since
+ # constraints.check_constraints() will complain if minimum and maximum are not
+ # set correctly. I'm (bolms@) not sure if the modulus/modular_value pulls its
+ # weight, but for completeness I've left it in.
+ if if_true.type.WhichOneof("type") == "integer":
+ # The minimum value of the choice is the minimum value of either side, and
+ # the maximum is the maximum value of either side.
+ expression.type.integer.minimum_value = str(
+ _min(
+ [
+ if_true.type.integer.minimum_value,
+ if_false.type.integer.minimum_value,
+ ]
+ )
+ )
+ expression.type.integer.maximum_value = str(
+ _max(
+ [
+ if_true.type.integer.maximum_value,
+ if_false.type.integer.maximum_value,
+ ]
+ )
+ )
+ new_modulus, new_modular_value = _shared_modular_value(
+ (if_true.type.integer.modulus, if_true.type.integer.modular_value),
+ (if_false.type.integer.modulus, if_false.type.integer.modular_value),
+ )
+ expression.type.integer.modulus = str(new_modulus)
+ expression.type.integer.modular_value = str(new_modular_value)
+ else:
+ assert if_true.type.WhichOneof("type") in (
+ "boolean",
+ "enumeration",
+ ), "Unknown type {} for expression".format(if_true.type.WhichOneof("type"))
def _greatest_common_divisor(a, b):
- """Returns the greatest common divisor of a and b.
+ """Returns the greatest common divisor of a and b.
- Arguments:
- a: an integer, a stringified integer, or the string "infinity"
- b: an integer, a stringified integer, or the string "infinity"
+ Arguments:
+ a: an integer, a stringified integer, or the string "infinity"
+ b: an integer, a stringified integer, or the string "infinity"
- Returns:
- Conceptually, "infinity" is treated as the product of all integers.
+ Returns:
+ Conceptually, "infinity" is treated as the product of all integers.
- If both a and b are 0, returns "infinity".
+ If both a and b are 0, returns "infinity".
- Otherwise, if either a or b are "infinity", and the other is 0, returns
- "infinity".
+ Otherwise, if either a or b are "infinity", and the other is 0, returns
+ "infinity".
- Otherwise, if either a or b are "infinity", returns the other.
+ Otherwise, if either a or b are "infinity", returns the other.
- Otherwise, returns the greatest common divisor of a and b.
- """
- if a != "infinity": a = int(a)
- if b != "infinity": b = int(b)
- assert a == "infinity" or a >= 0
- assert b == "infinity" or b >= 0
- if a == b == 0: return "infinity"
- # GCD(0, x) is always x, so it's safe to shortcut when a == 0 or b == 0.
- if a == 0: return b
- if b == 0: return a
- if a == "infinity": return b
- if b == "infinity": return a
- return _math_gcd(a, b)
+ Otherwise, returns the greatest common divisor of a and b.
+ """
+ if a != "infinity":
+ a = int(a)
+ if b != "infinity":
+ b = int(b)
+ assert a == "infinity" or a >= 0
+ assert b == "infinity" or b >= 0
+ if a == b == 0:
+ return "infinity"
+ # GCD(0, x) is always x, so it's safe to shortcut when a == 0 or b == 0.
+ if a == 0:
+ return b
+ if b == 0:
+ return a
+ if a == "infinity":
+ return b
+ if b == "infinity":
+ return a
+ return _math_gcd(a, b)
def compute_constants(ir):
- """Computes constant values for all expressions in ir.
+ """Computes constant values for all expressions in ir.
- compute_constants calculates all constant values and adds them to the type
- information for each expression and subexpression.
+ compute_constants calculates all constant values and adds them to the type
+ information for each expression and subexpression.
- Arguments:
- ir: an IR on which to compute constants
+ Arguments:
+ ir: an IR on which to compute constants
- Returns:
- A (possibly empty) list of errors.
- """
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Expression], compute_constraints_of_expression,
- skip_descendants_of={ir_data.Expression})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.RuntimeParameter], _compute_constraints_of_parameter,
- skip_descendants_of={ir_data.Expression})
- return []
+ Returns:
+ A (possibly empty) list of errors.
+ """
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Expression],
+ compute_constraints_of_expression,
+ skip_descendants_of={ir_data.Expression},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.RuntimeParameter],
+ _compute_constraints_of_parameter,
+ skip_descendants_of={ir_data.Expression},
+ )
+ return []
diff --git a/compiler/front_end/expression_bounds_test.py b/compiler/front_end/expression_bounds_test.py
index 54fa0ce..7af6836 100644
--- a/compiler/front_end/expression_bounds_test.py
+++ b/compiler/front_end/expression_bounds_test.py
@@ -22,1105 +22,1204 @@
class ComputeConstantsTest(unittest.TestCase):
- def _make_ir(self, emb_text):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": emb_text}),
- stop_before_step="compute_constants")
- assert not errors, errors
- return ir
+ def _make_ir(self, emb_text):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_text}),
+ stop_before_step="compute_constants",
+ )
+ assert not errors, errors
+ return ir
- def test_constant_integer(self):
- ir = self._make_ir("struct Foo:\n"
- " 10 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- start = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual("10", start.type.integer.minimum_value)
- self.assertEqual("10", start.type.integer.maximum_value)
- self.assertEqual("10", start.type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
+ def test_constant_integer(self):
+ ir = self._make_ir("struct Foo:\n" " 10 [+1] UInt x\n")
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ start = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual("10", start.type.integer.minimum_value)
+ self.assertEqual("10", start.type.integer.maximum_value)
+ self.assertEqual("10", start.type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
- def test_boolean_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " if true:\n"
- " 0 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expression = ir.module[0].type[0].structure.field[0].existence_condition
- self.assertTrue(expression.type.boolean.HasField("value"))
- self.assertTrue(expression.type.boolean.value)
+ def test_boolean_constant(self):
+ ir = self._make_ir("struct Foo:\n" " if true:\n" " 0 [+1] UInt x\n")
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expression = ir.module[0].type[0].structure.field[0].existence_condition
+ self.assertTrue(expression.type.boolean.HasField("value"))
+ self.assertTrue(expression.type.boolean.value)
- def test_constant_equality(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 == 5:\n"
- " 0 [+1] UInt x\n"
- " if 5 == 6:\n"
- " 0 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- true_condition = structure.field[0].existence_condition
- false_condition = structure.field[1].existence_condition
- self.assertTrue(true_condition.type.boolean.HasField("value"))
- self.assertTrue(true_condition.type.boolean.value)
- self.assertTrue(false_condition.type.boolean.HasField("value"))
- self.assertFalse(false_condition.type.boolean.value)
+ def test_constant_equality(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 == 5:\n"
+ " 0 [+1] UInt x\n"
+ " if 5 == 6:\n"
+ " 0 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ true_condition = structure.field[0].existence_condition
+ false_condition = structure.field[1].existence_condition
+ self.assertTrue(true_condition.type.boolean.HasField("value"))
+ self.assertTrue(true_condition.type.boolean.value)
+ self.assertTrue(false_condition.type.boolean.HasField("value"))
+ self.assertFalse(false_condition.type.boolean.value)
- def test_constant_inequality(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 != 5:\n"
- " 0 [+1] UInt x\n"
- " if 5 != 6:\n"
- " 0 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- false_condition = structure.field[0].existence_condition
- true_condition = structure.field[1].existence_condition
- self.assertTrue(false_condition.type.boolean.HasField("value"))
- self.assertFalse(false_condition.type.boolean.value)
- self.assertTrue(true_condition.type.boolean.HasField("value"))
- self.assertTrue(true_condition.type.boolean.value)
+ def test_constant_inequality(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 != 5:\n"
+ " 0 [+1] UInt x\n"
+ " if 5 != 6:\n"
+ " 0 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ false_condition = structure.field[0].existence_condition
+ true_condition = structure.field[1].existence_condition
+ self.assertTrue(false_condition.type.boolean.HasField("value"))
+ self.assertFalse(false_condition.type.boolean.value)
+ self.assertTrue(true_condition.type.boolean.HasField("value"))
+ self.assertTrue(true_condition.type.boolean.value)
- def test_constant_less_than(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 < 4:\n"
- " 0 [+1] UInt x\n"
- " if 5 < 5:\n"
- " 0 [+1] UInt y\n"
- " if 5 < 6:\n"
- " 0 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- greater_than_condition = structure.field[0].existence_condition
- equal_condition = structure.field[1].existence_condition
- less_than_condition = structure.field[2].existence_condition
- self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
- self.assertFalse(greater_than_condition.type.boolean.value)
- self.assertTrue(equal_condition.type.boolean.HasField("value"))
- self.assertFalse(equal_condition.type.boolean.value)
- self.assertTrue(less_than_condition.type.boolean.HasField("value"))
- self.assertTrue(less_than_condition.type.boolean.value)
+ def test_constant_less_than(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 < 4:\n"
+ " 0 [+1] UInt x\n"
+ " if 5 < 5:\n"
+ " 0 [+1] UInt y\n"
+ " if 5 < 6:\n"
+ " 0 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ greater_than_condition = structure.field[0].existence_condition
+ equal_condition = structure.field[1].existence_condition
+ less_than_condition = structure.field[2].existence_condition
+ self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
+ self.assertFalse(greater_than_condition.type.boolean.value)
+ self.assertTrue(equal_condition.type.boolean.HasField("value"))
+ self.assertFalse(equal_condition.type.boolean.value)
+ self.assertTrue(less_than_condition.type.boolean.HasField("value"))
+ self.assertTrue(less_than_condition.type.boolean.value)
- def test_constant_less_than_or_equal(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 <= 4:\n"
- " 0 [+1] UInt x\n"
- " if 5 <= 5:\n"
- " 0 [+1] UInt y\n"
- " if 5 <= 6:\n"
- " 0 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- greater_than_condition = structure.field[0].existence_condition
- equal_condition = structure.field[1].existence_condition
- less_than_condition = structure.field[2].existence_condition
- self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
- self.assertFalse(greater_than_condition.type.boolean.value)
- self.assertTrue(equal_condition.type.boolean.HasField("value"))
- self.assertTrue(equal_condition.type.boolean.value)
- self.assertTrue(less_than_condition.type.boolean.HasField("value"))
- self.assertTrue(less_than_condition.type.boolean.value)
+ def test_constant_less_than_or_equal(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 <= 4:\n"
+ " 0 [+1] UInt x\n"
+ " if 5 <= 5:\n"
+ " 0 [+1] UInt y\n"
+ " if 5 <= 6:\n"
+ " 0 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ greater_than_condition = structure.field[0].existence_condition
+ equal_condition = structure.field[1].existence_condition
+ less_than_condition = structure.field[2].existence_condition
+ self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
+ self.assertFalse(greater_than_condition.type.boolean.value)
+ self.assertTrue(equal_condition.type.boolean.HasField("value"))
+ self.assertTrue(equal_condition.type.boolean.value)
+ self.assertTrue(less_than_condition.type.boolean.HasField("value"))
+ self.assertTrue(less_than_condition.type.boolean.value)
- def test_constant_greater_than(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 > 4:\n"
- " 0 [+1] UInt x\n"
- " if 5 > 5:\n"
- " 0 [+1] UInt y\n"
- " if 5 > 6:\n"
- " 0 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- greater_than_condition = structure.field[0].existence_condition
- equal_condition = structure.field[1].existence_condition
- less_than_condition = structure.field[2].existence_condition
- self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
- self.assertTrue(greater_than_condition.type.boolean.value)
- self.assertTrue(equal_condition.type.boolean.HasField("value"))
- self.assertFalse(equal_condition.type.boolean.value)
- self.assertTrue(less_than_condition.type.boolean.HasField("value"))
- self.assertFalse(less_than_condition.type.boolean.value)
+ def test_constant_greater_than(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 > 4:\n"
+ " 0 [+1] UInt x\n"
+ " if 5 > 5:\n"
+ " 0 [+1] UInt y\n"
+ " if 5 > 6:\n"
+ " 0 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ greater_than_condition = structure.field[0].existence_condition
+ equal_condition = structure.field[1].existence_condition
+ less_than_condition = structure.field[2].existence_condition
+ self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
+ self.assertTrue(greater_than_condition.type.boolean.value)
+ self.assertTrue(equal_condition.type.boolean.HasField("value"))
+ self.assertFalse(equal_condition.type.boolean.value)
+ self.assertTrue(less_than_condition.type.boolean.HasField("value"))
+ self.assertFalse(less_than_condition.type.boolean.value)
- def test_constant_greater_than_or_equal(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 >= 4:\n"
- " 0 [+1] UInt x\n"
- " if 5 >= 5:\n"
- " 0 [+1] UInt y\n"
- " if 5 >= 6:\n"
- " 0 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- greater_than_condition = structure.field[0].existence_condition
- equal_condition = structure.field[1].existence_condition
- less_than_condition = structure.field[2].existence_condition
- self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
- self.assertTrue(greater_than_condition.type.boolean.value)
- self.assertTrue(equal_condition.type.boolean.HasField("value"))
- self.assertTrue(equal_condition.type.boolean.value)
- self.assertTrue(less_than_condition.type.boolean.HasField("value"))
- self.assertFalse(less_than_condition.type.boolean.value)
+ def test_constant_greater_than_or_equal(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 >= 4:\n"
+ " 0 [+1] UInt x\n"
+ " if 5 >= 5:\n"
+ " 0 [+1] UInt y\n"
+ " if 5 >= 6:\n"
+ " 0 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ greater_than_condition = structure.field[0].existence_condition
+ equal_condition = structure.field[1].existence_condition
+ less_than_condition = structure.field[2].existence_condition
+ self.assertTrue(greater_than_condition.type.boolean.HasField("value"))
+ self.assertTrue(greater_than_condition.type.boolean.value)
+ self.assertTrue(equal_condition.type.boolean.HasField("value"))
+ self.assertTrue(equal_condition.type.boolean.value)
+ self.assertTrue(less_than_condition.type.boolean.HasField("value"))
+ self.assertFalse(less_than_condition.type.boolean.value)
- def test_constant_and(self):
- ir = self._make_ir("struct Foo:\n"
- " if false && false:\n"
- " 0 [+1] UInt x\n"
- " if true && false:\n"
- " 0 [+1] UInt y\n"
- " if false && true:\n"
- " 0 [+1] UInt z\n"
- " if true && true:\n"
- " 0 [+1] UInt w\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- false_false_condition = structure.field[0].existence_condition
- true_false_condition = structure.field[1].existence_condition
- false_true_condition = structure.field[2].existence_condition
- true_true_condition = structure.field[3].existence_condition
- self.assertTrue(false_false_condition.type.boolean.HasField("value"))
- self.assertFalse(false_false_condition.type.boolean.value)
- self.assertTrue(true_false_condition.type.boolean.HasField("value"))
- self.assertFalse(true_false_condition.type.boolean.value)
- self.assertTrue(false_true_condition.type.boolean.HasField("value"))
- self.assertFalse(false_true_condition.type.boolean.value)
- self.assertTrue(true_true_condition.type.boolean.HasField("value"))
- self.assertTrue(true_true_condition.type.boolean.value)
+ def test_constant_and(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if false && false:\n"
+ " 0 [+1] UInt x\n"
+ " if true && false:\n"
+ " 0 [+1] UInt y\n"
+ " if false && true:\n"
+ " 0 [+1] UInt z\n"
+ " if true && true:\n"
+ " 0 [+1] UInt w\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ false_false_condition = structure.field[0].existence_condition
+ true_false_condition = structure.field[1].existence_condition
+ false_true_condition = structure.field[2].existence_condition
+ true_true_condition = structure.field[3].existence_condition
+ self.assertTrue(false_false_condition.type.boolean.HasField("value"))
+ self.assertFalse(false_false_condition.type.boolean.value)
+ self.assertTrue(true_false_condition.type.boolean.HasField("value"))
+ self.assertFalse(true_false_condition.type.boolean.value)
+ self.assertTrue(false_true_condition.type.boolean.HasField("value"))
+ self.assertFalse(false_true_condition.type.boolean.value)
+ self.assertTrue(true_true_condition.type.boolean.HasField("value"))
+ self.assertTrue(true_true_condition.type.boolean.value)
- def test_constant_or(self):
- ir = self._make_ir("struct Foo:\n"
- " if false || false:\n"
- " 0 [+1] UInt x\n"
- " if true || false:\n"
- " 0 [+1] UInt y\n"
- " if false || true:\n"
- " 0 [+1] UInt z\n"
- " if true || true:\n"
- " 0 [+1] UInt w\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- false_false_condition = structure.field[0].existence_condition
- true_false_condition = structure.field[1].existence_condition
- false_true_condition = structure.field[2].existence_condition
- true_true_condition = structure.field[3].existence_condition
- self.assertTrue(false_false_condition.type.boolean.HasField("value"))
- self.assertFalse(false_false_condition.type.boolean.value)
- self.assertTrue(true_false_condition.type.boolean.HasField("value"))
- self.assertTrue(true_false_condition.type.boolean.value)
- self.assertTrue(false_true_condition.type.boolean.HasField("value"))
- self.assertTrue(false_true_condition.type.boolean.value)
- self.assertTrue(true_true_condition.type.boolean.HasField("value"))
- self.assertTrue(true_true_condition.type.boolean.value)
+ def test_constant_or(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if false || false:\n"
+ " 0 [+1] UInt x\n"
+ " if true || false:\n"
+ " 0 [+1] UInt y\n"
+ " if false || true:\n"
+ " 0 [+1] UInt z\n"
+ " if true || true:\n"
+ " 0 [+1] UInt w\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ false_false_condition = structure.field[0].existence_condition
+ true_false_condition = structure.field[1].existence_condition
+ false_true_condition = structure.field[2].existence_condition
+ true_true_condition = structure.field[3].existence_condition
+ self.assertTrue(false_false_condition.type.boolean.HasField("value"))
+ self.assertFalse(false_false_condition.type.boolean.value)
+ self.assertTrue(true_false_condition.type.boolean.HasField("value"))
+ self.assertTrue(true_false_condition.type.boolean.value)
+ self.assertTrue(false_true_condition.type.boolean.HasField("value"))
+ self.assertTrue(false_true_condition.type.boolean.value)
+ self.assertTrue(true_true_condition.type.boolean.HasField("value"))
+ self.assertTrue(true_true_condition.type.boolean.value)
- def test_enum_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " if Bar.QUX == Bar.QUX:\n"
- " 0 [+1] Bar x\n"
- "enum Bar:\n"
- " QUX = 12\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- condition = ir.module[0].type[0].structure.field[0].existence_condition
- left = condition.function.args[0]
- self.assertEqual("12", left.type.enumeration.value)
+ def test_enum_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if Bar.QUX == Bar.QUX:\n"
+ " 0 [+1] Bar x\n"
+ "enum Bar:\n"
+ " QUX = 12\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ condition = ir.module[0].type[0].structure.field[0].existence_condition
+ left = condition.function.args[0]
+ self.assertEqual("12", left.type.enumeration.value)
- def test_non_constant_field_reference(self):
- ir = self._make_ir("struct Foo:\n"
- " y [+1] UInt x\n"
- " 0 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- start = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual("0", start.type.integer.minimum_value)
- self.assertEqual("255", start.type.integer.maximum_value)
- self.assertEqual("0", start.type.integer.modular_value)
- self.assertEqual("1", start.type.integer.modulus)
+ def test_non_constant_field_reference(self):
+ ir = self._make_ir("struct Foo:\n" " y [+1] UInt x\n" " 0 [+1] UInt y\n")
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ start = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual("0", start.type.integer.minimum_value)
+ self.assertEqual("255", start.type.integer.maximum_value)
+ self.assertEqual("0", start.type.integer.modular_value)
+ self.assertEqual("1", start.type.integer.modulus)
- def test_field_reference_bounds_are_uncomputable(self):
- # Variable-sized UInt/Int/Bcd should not cause an error here: they are
- # handled in the constraints pass.
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 0 [+x] UInt y\n"
- " y [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
+ def test_field_reference_bounds_are_uncomputable(self):
+ # Variable-sized UInt/Int/Bcd should not cause an error here: they are
+ # handled in the constraints pass.
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 0 [+x] UInt y\n"
+ " y [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
- def test_field_references_references_bounds_are_uncomputable(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 0 [+x] UInt y\n"
- " 0 [+y] UInt z\n"
- " z [+1] UInt q\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
+ def test_field_references_references_bounds_are_uncomputable(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 0 [+x] UInt y\n"
+ " 0 [+y] UInt z\n"
+ " z [+1] UInt q\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
- def test_non_constant_equality(self):
- ir = self._make_ir("struct Foo:\n"
- " if 5 == y:\n"
- " 0 [+1] UInt x\n"
- " 0 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- structure = ir.module[0].type[0].structure
- condition = structure.field[0].existence_condition
- self.assertFalse(condition.type.boolean.HasField("value"))
+ def test_non_constant_equality(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if 5 == y:\n"
+ " 0 [+1] UInt x\n"
+ " 0 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ structure = ir.module[0].type[0].structure
+ condition = structure.field[0].existence_condition
+ self.assertFalse(condition.type.boolean.HasField("value"))
- def test_constant_addition(self):
- ir = self._make_ir("struct Foo:\n"
- " 7+5 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- start = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual("12", start.type.integer.minimum_value)
- self.assertEqual("12", start.type.integer.maximum_value)
- self.assertEqual("12", start.type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
- self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
- self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
- self.assertEqual("7", start.function.args[0].type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
- self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
- self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
- self.assertEqual("5", start.function.args[1].type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
+ def test_constant_addition(self):
+ ir = self._make_ir("struct Foo:\n" " 7+5 [+1] UInt x\n")
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ start = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual("12", start.type.integer.minimum_value)
+ self.assertEqual("12", start.type.integer.maximum_value)
+ self.assertEqual("12", start.type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
+ self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
+ self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
+ self.assertEqual("7", start.function.args[0].type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
+ self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
+ self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
+ self.assertEqual("5", start.function.args[1].type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
- def test_constant_subtraction(self):
- ir = self._make_ir("struct Foo:\n"
- " 7-5 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- start = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual("2", start.type.integer.minimum_value)
- self.assertEqual("2", start.type.integer.maximum_value)
- self.assertEqual("2", start.type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
- self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
- self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
- self.assertEqual("7", start.function.args[0].type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
- self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
- self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
- self.assertEqual("5", start.function.args[1].type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
+ def test_constant_subtraction(self):
+ ir = self._make_ir("struct Foo:\n" " 7-5 [+1] UInt x\n")
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ start = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual("2", start.type.integer.minimum_value)
+ self.assertEqual("2", start.type.integer.maximum_value)
+ self.assertEqual("2", start.type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
+ self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
+ self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
+ self.assertEqual("7", start.function.args[0].type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
+ self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
+ self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
+ self.assertEqual("5", start.function.args[1].type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
- def test_constant_multiplication(self):
- ir = self._make_ir("struct Foo:\n"
- " 7*5 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- start = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual("35", start.type.integer.minimum_value)
- self.assertEqual("35", start.type.integer.maximum_value)
- self.assertEqual("35", start.type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
- self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
- self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
- self.assertEqual("7", start.function.args[0].type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
- self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
- self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
- self.assertEqual("5", start.function.args[1].type.integer.modular_value)
- self.assertEqual("infinity", start.type.integer.modulus)
+ def test_constant_multiplication(self):
+ ir = self._make_ir("struct Foo:\n" " 7*5 [+1] UInt x\n")
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ start = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual("35", start.type.integer.minimum_value)
+ self.assertEqual("35", start.type.integer.maximum_value)
+ self.assertEqual("35", start.type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
+ self.assertEqual("7", start.function.args[0].type.integer.minimum_value)
+ self.assertEqual("7", start.function.args[0].type.integer.maximum_value)
+ self.assertEqual("7", start.function.args[0].type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
+ self.assertEqual("5", start.function.args[1].type.integer.minimum_value)
+ self.assertEqual("5", start.function.args[1].type.integer.maximum_value)
+ self.assertEqual("5", start.function.args[1].type.integer.modular_value)
+ self.assertEqual("infinity", start.type.integer.modulus)
- def test_nested_constant_expression(self):
- ir = self._make_ir("struct Foo:\n"
- " if 7*(3+1) == 28:\n"
- " 0 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- condition = ir.module[0].type[0].structure.field[0].existence_condition
- self.assertTrue(condition.type.boolean.value)
- condition_left = condition.function.args[0]
- self.assertEqual("28", condition_left.type.integer.minimum_value)
- self.assertEqual("28", condition_left.type.integer.maximum_value)
- self.assertEqual("28", condition_left.type.integer.modular_value)
- self.assertEqual("infinity", condition_left.type.integer.modulus)
- condition_left_left = condition_left.function.args[0]
- self.assertEqual("7", condition_left_left.type.integer.minimum_value)
- self.assertEqual("7", condition_left_left.type.integer.maximum_value)
- self.assertEqual("7", condition_left_left.type.integer.modular_value)
- self.assertEqual("infinity", condition_left_left.type.integer.modulus)
- condition_left_right = condition_left.function.args[1]
- self.assertEqual("4", condition_left_right.type.integer.minimum_value)
- self.assertEqual("4", condition_left_right.type.integer.maximum_value)
- self.assertEqual("4", condition_left_right.type.integer.modular_value)
- self.assertEqual("infinity", condition_left_right.type.integer.modulus)
- condition_left_right_left = condition_left_right.function.args[0]
- self.assertEqual("3", condition_left_right_left.type.integer.minimum_value)
- self.assertEqual("3", condition_left_right_left.type.integer.maximum_value)
- self.assertEqual("3", condition_left_right_left.type.integer.modular_value)
- self.assertEqual("infinity", condition_left_right_left.type.integer.modulus)
- condition_left_right_right = condition_left_right.function.args[1]
- self.assertEqual("1", condition_left_right_right.type.integer.minimum_value)
- self.assertEqual("1", condition_left_right_right.type.integer.maximum_value)
- self.assertEqual("1", condition_left_right_right.type.integer.modular_value)
- self.assertEqual("infinity",
- condition_left_right_right.type.integer.modulus)
+ def test_nested_constant_expression(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " if 7*(3+1) == 28:\n" " 0 [+1] UInt x\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ condition = ir.module[0].type[0].structure.field[0].existence_condition
+ self.assertTrue(condition.type.boolean.value)
+ condition_left = condition.function.args[0]
+ self.assertEqual("28", condition_left.type.integer.minimum_value)
+ self.assertEqual("28", condition_left.type.integer.maximum_value)
+ self.assertEqual("28", condition_left.type.integer.modular_value)
+ self.assertEqual("infinity", condition_left.type.integer.modulus)
+ condition_left_left = condition_left.function.args[0]
+ self.assertEqual("7", condition_left_left.type.integer.minimum_value)
+ self.assertEqual("7", condition_left_left.type.integer.maximum_value)
+ self.assertEqual("7", condition_left_left.type.integer.modular_value)
+ self.assertEqual("infinity", condition_left_left.type.integer.modulus)
+ condition_left_right = condition_left.function.args[1]
+ self.assertEqual("4", condition_left_right.type.integer.minimum_value)
+ self.assertEqual("4", condition_left_right.type.integer.maximum_value)
+ self.assertEqual("4", condition_left_right.type.integer.modular_value)
+ self.assertEqual("infinity", condition_left_right.type.integer.modulus)
+ condition_left_right_left = condition_left_right.function.args[0]
+ self.assertEqual("3", condition_left_right_left.type.integer.minimum_value)
+ self.assertEqual("3", condition_left_right_left.type.integer.maximum_value)
+ self.assertEqual("3", condition_left_right_left.type.integer.modular_value)
+ self.assertEqual("infinity", condition_left_right_left.type.integer.modulus)
+ condition_left_right_right = condition_left_right.function.args[1]
+ self.assertEqual("1", condition_left_right_right.type.integer.minimum_value)
+ self.assertEqual("1", condition_left_right_right.type.integer.maximum_value)
+ self.assertEqual("1", condition_left_right_right.type.integer.modular_value)
+ self.assertEqual("infinity", condition_left_right_right.type.integer.modulus)
- def test_constant_plus_non_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 5+(4*x) [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- y_start = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual("4", y_start.type.integer.modulus)
- self.assertEqual("1", y_start.type.integer.modular_value)
- self.assertEqual("5", y_start.type.integer.minimum_value)
- self.assertEqual("1025", y_start.type.integer.maximum_value)
+ def test_constant_plus_non_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 5+(4*x) [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ y_start = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual("4", y_start.type.integer.modulus)
+ self.assertEqual("1", y_start.type.integer.modular_value)
+ self.assertEqual("5", y_start.type.integer.minimum_value)
+ self.assertEqual("1025", y_start.type.integer.maximum_value)
- def test_constant_minus_non_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 5-(4*x) [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- y_start = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual("4", y_start.type.integer.modulus)
- self.assertEqual("1", y_start.type.integer.modular_value)
- self.assertEqual("-1015", y_start.type.integer.minimum_value)
- self.assertEqual("5", y_start.type.integer.maximum_value)
+ def test_constant_minus_non_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 5-(4*x) [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ y_start = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual("4", y_start.type.integer.modulus)
+ self.assertEqual("1", y_start.type.integer.modular_value)
+ self.assertEqual("-1015", y_start.type.integer.minimum_value)
+ self.assertEqual("5", y_start.type.integer.maximum_value)
- def test_non_constant_minus_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " (4*x)-5 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- y_start = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual(str((4 * 0) - 5), y_start.type.integer.minimum_value)
- self.assertEqual(str((4 * 255) - 5), y_start.type.integer.maximum_value)
- self.assertEqual("4", y_start.type.integer.modulus)
- self.assertEqual("3", y_start.type.integer.modular_value)
+ def test_non_constant_minus_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " (4*x)-5 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ y_start = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual(str((4 * 0) - 5), y_start.type.integer.minimum_value)
+ self.assertEqual(str((4 * 255) - 5), y_start.type.integer.maximum_value)
+ self.assertEqual("4", y_start.type.integer.modulus)
+ self.assertEqual("3", y_start.type.integer.modular_value)
- def test_non_constant_plus_non_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (4*x)+(6*y+3) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("3", z_start.type.integer.minimum_value)
- self.assertEqual(str(4 * 255 + 6 * 255 + 3),
- z_start.type.integer.maximum_value)
- self.assertEqual("2", z_start.type.integer.modulus)
- self.assertEqual("1", z_start.type.integer.modular_value)
+ def test_non_constant_plus_non_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (4*x)+(6*y+3) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("3", z_start.type.integer.minimum_value)
+ self.assertEqual(str(4 * 255 + 6 * 255 + 3), z_start.type.integer.maximum_value)
+ self.assertEqual("2", z_start.type.integer.modulus)
+ self.assertEqual("1", z_start.type.integer.modular_value)
- def test_non_constant_minus_non_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (x*3)-(y*3) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("3", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual(str(-3 * 255), z_start.type.integer.minimum_value)
- self.assertEqual(str(3 * 255), z_start.type.integer.maximum_value)
+ def test_non_constant_minus_non_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (x*3)-(y*3) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("3", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual(str(-3 * 255), z_start.type.integer.minimum_value)
+ self.assertEqual(str(3 * 255), z_start.type.integer.maximum_value)
- def test_non_constant_times_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " (4*x+1)*5 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- y_start = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual("20", y_start.type.integer.modulus)
- self.assertEqual("5", y_start.type.integer.modular_value)
- self.assertEqual("5", y_start.type.integer.minimum_value)
- self.assertEqual(str((4 * 255 + 1) * 5), y_start.type.integer.maximum_value)
+ def test_non_constant_times_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " (4*x+1)*5 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ y_start = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual("20", y_start.type.integer.modulus)
+ self.assertEqual("5", y_start.type.integer.modular_value)
+ self.assertEqual("5", y_start.type.integer.minimum_value)
+ self.assertEqual(str((4 * 255 + 1) * 5), y_start.type.integer.maximum_value)
- def test_non_constant_times_negative_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " (4*x+1)*-5 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- y_start = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual("20", y_start.type.integer.modulus)
- self.assertEqual("15", y_start.type.integer.modular_value)
- self.assertEqual(str((4 * 255 + 1) * -5),
- y_start.type.integer.minimum_value)
- self.assertEqual("-5", y_start.type.integer.maximum_value)
+ def test_non_constant_times_negative_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " (4*x+1)*-5 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ y_start = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual("20", y_start.type.integer.modulus)
+ self.assertEqual("15", y_start.type.integer.modular_value)
+ self.assertEqual(str((4 * 255 + 1) * -5), y_start.type.integer.minimum_value)
+ self.assertEqual("-5", y_start.type.integer.maximum_value)
- def test_non_constant_times_zero(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " (4*x+1)*0 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- y_start = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual("infinity", y_start.type.integer.modulus)
- self.assertEqual("0", y_start.type.integer.modular_value)
- self.assertEqual("0", y_start.type.integer.minimum_value)
- self.assertEqual("0", y_start.type.integer.maximum_value)
+ def test_non_constant_times_zero(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " (4*x+1)*0 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ y_start = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual("infinity", y_start.type.integer.modulus)
+ self.assertEqual("0", y_start.type.integer.modular_value)
+ self.assertEqual("0", y_start.type.integer.minimum_value)
+ self.assertEqual("0", y_start.type.integer.maximum_value)
- def test_non_constant_times_non_constant_shared_modulus(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (4*x+3)*(4*y+3) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("4", z_start.type.integer.modulus)
- self.assertEqual("1", z_start.type.integer.modular_value)
- self.assertEqual("9", z_start.type.integer.minimum_value)
- self.assertEqual(str((4 * 255 + 3)**2), z_start.type.integer.maximum_value)
+ def test_non_constant_times_non_constant_shared_modulus(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (4*x+3)*(4*y+3) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("4", z_start.type.integer.modulus)
+ self.assertEqual("1", z_start.type.integer.modular_value)
+ self.assertEqual("9", z_start.type.integer.minimum_value)
+ self.assertEqual(str((4 * 255 + 3) ** 2), z_start.type.integer.maximum_value)
- def test_non_constant_times_non_constant_congruent_to_zero(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (4*x)*(4*y) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("16", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual("0", z_start.type.integer.minimum_value)
- self.assertEqual(str((4 * 255)**2), z_start.type.integer.maximum_value)
+ def test_non_constant_times_non_constant_congruent_to_zero(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (4*x)*(4*y) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("16", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual("0", z_start.type.integer.minimum_value)
+ self.assertEqual(str((4 * 255) ** 2), z_start.type.integer.maximum_value)
- def test_non_constant_times_non_constant_partially_shared_modulus(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (4*x+3)*(8*y+3) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("4", z_start.type.integer.modulus)
- self.assertEqual("1", z_start.type.integer.modular_value)
- self.assertEqual("9", z_start.type.integer.minimum_value)
- self.assertEqual(str((4 * 255 + 3) * (8 * 255 + 3)),
- z_start.type.integer.maximum_value)
+ def test_non_constant_times_non_constant_partially_shared_modulus(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (4*x+3)*(8*y+3) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("4", z_start.type.integer.modulus)
+ self.assertEqual("1", z_start.type.integer.modular_value)
+ self.assertEqual("9", z_start.type.integer.minimum_value)
+ self.assertEqual(
+ str((4 * 255 + 3) * (8 * 255 + 3)), z_start.type.integer.maximum_value
+ )
- def test_non_constant_times_non_constant_full_complexity(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (12*x+9)*(40*y+15) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("60", z_start.type.integer.modulus)
- self.assertEqual("15", z_start.type.integer.modular_value)
- self.assertEqual(str(9 * 15), z_start.type.integer.minimum_value)
- self.assertEqual(str((12 * 255 + 9) * (40 * 255 + 15)),
- z_start.type.integer.maximum_value)
+ def test_non_constant_times_non_constant_full_complexity(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (12*x+9)*(40*y+15) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("60", z_start.type.integer.modulus)
+ self.assertEqual("15", z_start.type.integer.modular_value)
+ self.assertEqual(str(9 * 15), z_start.type.integer.minimum_value)
+ self.assertEqual(
+ str((12 * 255 + 9) * (40 * 255 + 15)), z_start.type.integer.maximum_value
+ )
- def test_signed_non_constant_times_signed_non_constant_full_complexity(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Int x\n"
- " 1 [+1] Int y\n"
- " (12*x+9)*(40*y+15) [+1] Int z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("60", z_start.type.integer.modulus)
- self.assertEqual("15", z_start.type.integer.modular_value)
- # Max x/min y is slightly lower than min x/max y (-7825965 vs -7780065).
- self.assertEqual(str((12 * 127 + 9) * (40 * -128 + 15)),
- z_start.type.integer.minimum_value)
- # Max x/max y is slightly higher than min x/min y (7810635 vs 7795335).
- self.assertEqual(str((12 * 127 + 9) * (40 * 127 + 15)),
- z_start.type.integer.maximum_value)
+ def test_signed_non_constant_times_signed_non_constant_full_complexity(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Int x\n"
+ " 1 [+1] Int y\n"
+ " (12*x+9)*(40*y+15) [+1] Int z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("60", z_start.type.integer.modulus)
+ self.assertEqual("15", z_start.type.integer.modular_value)
+ # Max x/min y is slightly lower than min x/max y (-7825965 vs -7780065).
+ self.assertEqual(
+ str((12 * 127 + 9) * (40 * -128 + 15)), z_start.type.integer.minimum_value
+ )
+ # Max x/max y is slightly higher than min x/min y (7810635 vs 7795335).
+ self.assertEqual(
+ str((12 * 127 + 9) * (40 * 127 + 15)), z_start.type.integer.maximum_value
+ )
- def test_non_constant_times_non_constant_flipped_min_max(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n"
- " (-x*3)*(y*3) [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("9", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual(str(-((3 * 255)**2)), z_start.type.integer.minimum_value)
- self.assertEqual("0", z_start.type.integer.maximum_value)
+ def test_non_constant_times_non_constant_flipped_min_max(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ " (-x*3)*(y*3) [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("9", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual(str(-((3 * 255) ** 2)), z_start.type.integer.minimum_value)
+ self.assertEqual("0", z_start.type.integer.maximum_value)
- # Currently, only `$static_size_in_bits` has an infinite bound, so all of the
- # examples below use `$static_size_in_bits`. Unfortunately, this also means
- # that these tests rely on the fact that Emboss doesn't try to do any term
- # rewriting or smart correlation between the arguments of various operators:
- # for example, several tests rely on `$static_size_in_bits -
- # $static_size_in_bits` having the range `-infinity` to `infinity`, when a
- # trivial term rewrite would turn that expression into `0`.
- #
- # Unbounded expressions are only allowed at compile-time anyway, so these
- # tests cover some fairly unlikely uses of the Emboss expression language.
- def test_unbounded_plus_constant(self):
- ir = self._make_ir("external Foo:\n"
- " [requires: $static_size_in_bits + 2 > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("2", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ # Currently, only `$static_size_in_bits` has an infinite bound, so all of the
+ # examples below use `$static_size_in_bits`. Unfortunately, this also means
+ # that these tests rely on the fact that Emboss doesn't try to do any term
+ # rewriting or smart correlation between the arguments of various operators:
+ # for example, several tests rely on `$static_size_in_bits -
+ # $static_size_in_bits` having the range `-infinity` to `infinity`, when a
+ # trivial term rewrite would turn that expression into `0`.
+ #
+ # Unbounded expressions are only allowed at compile-time anyway, so these
+ # tests cover some fairly unlikely uses of the Emboss expression language.
+ def test_unbounded_plus_constant(self):
+ ir = self._make_ir(
+ "external Foo:\n" " [requires: $static_size_in_bits + 2 > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("2", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_negative_unbounded_plus_constant(self):
- ir = self._make_ir("external Foo:\n"
- " [requires: -$static_size_in_bits + 2 > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("2", expr.type.integer.maximum_value)
+ def test_negative_unbounded_plus_constant(self):
+ ir = self._make_ir(
+ "external Foo:\n" " [requires: -$static_size_in_bits + 2 > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("2", expr.type.integer.maximum_value)
- def test_negative_unbounded_plus_unbounded(self):
- ir = self._make_ir(
- "external Foo:\n"
- " [requires: -$static_size_in_bits + $static_size_in_bits > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ def test_negative_unbounded_plus_unbounded(self):
+ ir = self._make_ir(
+ "external Foo:\n"
+ " [requires: -$static_size_in_bits + $static_size_in_bits > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_unbounded_minus_unbounded(self):
- ir = self._make_ir(
- "external Foo:\n"
- " [requires: $static_size_in_bits - $static_size_in_bits > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ def test_unbounded_minus_unbounded(self):
+ ir = self._make_ir(
+ "external Foo:\n"
+ " [requires: $static_size_in_bits - $static_size_in_bits > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_unbounded_minus_negative_unbounded(self):
- ir = self._make_ir(
- "external Foo:\n"
- " [requires: $static_size_in_bits - -$static_size_in_bits > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("0", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ def test_unbounded_minus_negative_unbounded(self):
+ ir = self._make_ir(
+ "external Foo:\n"
+ " [requires: $static_size_in_bits - -$static_size_in_bits > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("0", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_unbounded_times_constant(self):
- ir = self._make_ir("external Foo:\n"
- " [requires: ($static_size_in_bits + 1) * 2 > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("2", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("2", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ def test_unbounded_times_constant(self):
+ ir = self._make_ir(
+ "external Foo:\n" " [requires: ($static_size_in_bits + 1) * 2 > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("2", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("2", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_unbounded_times_negative_constant(self):
- ir = self._make_ir("external Foo:\n"
- " [requires: ($static_size_in_bits + 1) * -2 > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("2", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("-2", expr.type.integer.maximum_value)
+ def test_unbounded_times_negative_constant(self):
+ ir = self._make_ir(
+ "external Foo:\n" " [requires: ($static_size_in_bits + 1) * -2 > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("2", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("-2", expr.type.integer.maximum_value)
- def test_unbounded_times_negative_zero(self):
- ir = self._make_ir("external Foo:\n"
- " [requires: ($static_size_in_bits + 1) * 0 > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("infinity", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("0", expr.type.integer.minimum_value)
- self.assertEqual("0", expr.type.integer.maximum_value)
+ def test_unbounded_times_negative_zero(self):
+ ir = self._make_ir(
+ "external Foo:\n" " [requires: ($static_size_in_bits + 1) * 0 > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("infinity", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("0", expr.type.integer.minimum_value)
+ self.assertEqual("0", expr.type.integer.maximum_value)
- def test_negative_unbounded_times_constant(self):
- ir = self._make_ir("external Foo:\n"
- " [requires: (-$static_size_in_bits + 1) * 2 > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("2", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("2", expr.type.integer.maximum_value)
+ def test_negative_unbounded_times_constant(self):
+ ir = self._make_ir(
+ "external Foo:\n" " [requires: (-$static_size_in_bits + 1) * 2 > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("2", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("2", expr.type.integer.maximum_value)
- def test_double_unbounded_minus_unbounded(self):
- ir = self._make_ir(
- "external Foo:\n"
- " [requires: 2 * $static_size_in_bits - $static_size_in_bits > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ def test_double_unbounded_minus_unbounded(self):
+ ir = self._make_ir(
+ "external Foo:\n"
+ " [requires: 2 * $static_size_in_bits - $static_size_in_bits > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_double_unbounded_times_negative_unbounded(self):
- ir = self._make_ir(
- "external Foo:\n"
- " [requires: 2 * $static_size_in_bits * -$static_size_in_bits > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("2", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("0", expr.type.integer.maximum_value)
+ def test_double_unbounded_times_negative_unbounded(self):
+ ir = self._make_ir(
+ "external Foo:\n"
+ " [requires: 2 * $static_size_in_bits * -$static_size_in_bits > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("2", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("0", expr.type.integer.maximum_value)
- def test_upper_bound_of_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Int x\n"
- " let u = $upper_bound(x)\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- u_type = ir.module[0].type[0].structure.field[1].read_transform.type
- self.assertEqual("infinity", u_type.integer.modulus)
- self.assertEqual("127", u_type.integer.maximum_value)
- self.assertEqual("127", u_type.integer.minimum_value)
- self.assertEqual("127", u_type.integer.modular_value)
+ def test_upper_bound_of_field(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] Int x\n" " let u = $upper_bound(x)\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ u_type = ir.module[0].type[0].structure.field[1].read_transform.type
+ self.assertEqual("infinity", u_type.integer.modulus)
+ self.assertEqual("127", u_type.integer.maximum_value)
+ self.assertEqual("127", u_type.integer.minimum_value)
+ self.assertEqual("127", u_type.integer.modular_value)
- def test_lower_bound_of_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Int x\n"
- " let l = $lower_bound(x)\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- l_type = ir.module[0].type[0].structure.field[1].read_transform.type
- self.assertEqual("infinity", l_type.integer.modulus)
- self.assertEqual("-128", l_type.integer.maximum_value)
- self.assertEqual("-128", l_type.integer.minimum_value)
- self.assertEqual("-128", l_type.integer.modular_value)
+ def test_lower_bound_of_field(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] Int x\n" " let l = $lower_bound(x)\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ l_type = ir.module[0].type[0].structure.field[1].read_transform.type
+ self.assertEqual("infinity", l_type.integer.modulus)
+ self.assertEqual("-128", l_type.integer.maximum_value)
+ self.assertEqual("-128", l_type.integer.minimum_value)
+ self.assertEqual("-128", l_type.integer.modular_value)
- def test_upper_bound_of_max(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Int x\n"
- " 1 [+1] UInt y\n"
- " let u = $upper_bound($max(x, y))\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- u_type = ir.module[0].type[0].structure.field[2].read_transform.type
- self.assertEqual("infinity", u_type.integer.modulus)
- self.assertEqual("255", u_type.integer.maximum_value)
- self.assertEqual("255", u_type.integer.minimum_value)
- self.assertEqual("255", u_type.integer.modular_value)
+ def test_upper_bound_of_max(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Int x\n"
+ " 1 [+1] UInt y\n"
+ " let u = $upper_bound($max(x, y))\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ u_type = ir.module[0].type[0].structure.field[2].read_transform.type
+ self.assertEqual("infinity", u_type.integer.modulus)
+ self.assertEqual("255", u_type.integer.maximum_value)
+ self.assertEqual("255", u_type.integer.minimum_value)
+ self.assertEqual("255", u_type.integer.modular_value)
- def test_lower_bound_of_max(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Int x\n"
- " 1 [+1] UInt y\n"
- " let l = $lower_bound($max(x, y))\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- l_type = ir.module[0].type[0].structure.field[2].read_transform.type
- self.assertEqual("infinity", l_type.integer.modulus)
- self.assertEqual("0", l_type.integer.maximum_value)
- self.assertEqual("0", l_type.integer.minimum_value)
- self.assertEqual("0", l_type.integer.modular_value)
+ def test_lower_bound_of_max(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Int x\n"
+ " 1 [+1] UInt y\n"
+ " let l = $lower_bound($max(x, y))\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ l_type = ir.module[0].type[0].structure.field[2].read_transform.type
+ self.assertEqual("infinity", l_type.integer.modulus)
+ self.assertEqual("0", l_type.integer.maximum_value)
+ self.assertEqual("0", l_type.integer.minimum_value)
+ self.assertEqual("0", l_type.integer.modular_value)
- def test_double_unbounded_both_ends_times_negative_unbounded(self):
- ir = self._make_ir(
- "external Foo:\n"
- " [requires: (2 * ($static_size_in_bits - $static_size_in_bits) + 1) "
- " * -$static_size_in_bits > 0]\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
- self.assertEqual("1", expr.type.integer.modulus)
- self.assertEqual("0", expr.type.integer.modular_value)
- self.assertEqual("-infinity", expr.type.integer.minimum_value)
- self.assertEqual("infinity", expr.type.integer.maximum_value)
+ def test_double_unbounded_both_ends_times_negative_unbounded(self):
+ ir = self._make_ir(
+ "external Foo:\n"
+ " [requires: (2 * ($static_size_in_bits - $static_size_in_bits) + 1) "
+ " * -$static_size_in_bits > 0]\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].attribute[0].value.expression.function.args[0]
+ self.assertEqual("1", expr.type.integer.modulus)
+ self.assertEqual("0", expr.type.integer.modular_value)
+ self.assertEqual("-infinity", expr.type.integer.minimum_value)
+ self.assertEqual("infinity", expr.type.integer.maximum_value)
- def test_choice_two_non_constant_integers(self):
- cases = [
- # t % 12 == 7 and f % 20 == 15 ==> r % 4 == 3
- (12, 7, 20, 15, 4, 3, -128 * 20 + 15, 127 * 20 + 15),
- # t % 24 == 15 and f % 12 == 7 ==> r % 4 == 3
- (24, 15, 12, 7, 4, 3, -128 * 24 + 15, 127 * 24 + 15),
- # t % 20 == 15 and f % 20 == 10 ==> r % 5 == 0
- (20, 15, 20, 10, 5, 0, -128 * 20 + 10, 127 * 20 + 15),
- # t % 20 == 16 and f % 20 == 11 ==> r % 5 == 1
- (20, 16, 20, 11, 5, 1, -128 * 20 + 11, 127 * 20 + 16),
- ]
- for (t_mod, t_val, f_mod, f_val, r_mod, r_val, r_min, r_max) in cases:
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if (x == 0 ? y * {} + {} : y * {} + {}) == 0:\n"
- " 1 [+1] UInt z\n".format(
- t_mod, t_val, f_mod, f_val))
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[2]
- expr = field.existence_condition.function.args[0]
- self.assertEqual(str(r_mod), expr.type.integer.modulus)
- self.assertEqual(str(r_val), expr.type.integer.modular_value)
- self.assertEqual(str(r_min), expr.type.integer.minimum_value)
- self.assertEqual(str(r_max), expr.type.integer.maximum_value)
+ def test_choice_two_non_constant_integers(self):
+ cases = [
+ # t % 12 == 7 and f % 20 == 15 ==> r % 4 == 3
+ (12, 7, 20, 15, 4, 3, -128 * 20 + 15, 127 * 20 + 15),
+ # t % 24 == 15 and f % 12 == 7 ==> r % 4 == 3
+ (24, 15, 12, 7, 4, 3, -128 * 24 + 15, 127 * 24 + 15),
+ # t % 20 == 15 and f % 20 == 10 ==> r % 5 == 0
+ (20, 15, 20, 10, 5, 0, -128 * 20 + 10, 127 * 20 + 15),
+ # t % 20 == 16 and f % 20 == 11 ==> r % 5 == 1
+ (20, 16, 20, 11, 5, 1, -128 * 20 + 11, 127 * 20 + 16),
+ ]
+ for t_mod, t_val, f_mod, f_val, r_mod, r_val, r_min, r_max in cases:
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if (x == 0 ? y * {} + {} : y * {} + {}) == 0:\n"
+ " 1 [+1] UInt z\n".format(t_mod, t_val, f_mod, f_val)
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[2]
+ expr = field.existence_condition.function.args[0]
+ self.assertEqual(str(r_mod), expr.type.integer.modulus)
+ self.assertEqual(str(r_val), expr.type.integer.modular_value)
+ self.assertEqual(str(r_min), expr.type.integer.minimum_value)
+ self.assertEqual(str(r_max), expr.type.integer.maximum_value)
- def test_choice_one_non_constant_integer(self):
- cases = [
- # t == 35 and f % 20 == 15 ==> res % 20 == 15
- (35, 20, 15, 20, 15, -128 * 20 + 15, 127 * 20 + 15),
- # t == 200035 and f % 20 == 15 ==> res % 20 == 15
- (200035, 20, 15, 20, 15, -128 * 20 + 15, 200035),
- # t == 21 and f % 20 == 16 ==> res % 5 == 1
- (21, 20, 16, 5, 1, -128 * 20 + 16, 127 * 20 + 16),
- ]
- for (t_val, f_mod, f_val, r_mod, r_val, r_min, r_max) in cases:
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if (x == 0 ? {0} : y * {1} + {2}) == 0:\n"
- " 1 [+1] UInt z\n"
- " if (x == 0 ? y * {1} + {2} : {0}) == 0:\n"
- " 1 [+1] UInt q\n".format(t_val, f_mod, f_val))
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field_constant_true = ir.module[0].type[0].structure.field[2]
- constant_true = field_constant_true.existence_condition.function.args[0]
- field_constant_false = ir.module[0].type[0].structure.field[3]
- constant_false = field_constant_false.existence_condition.function.args[0]
- self.assertEqual(str(r_mod), constant_true.type.integer.modulus)
- self.assertEqual(str(r_val), constant_true.type.integer.modular_value)
- self.assertEqual(str(r_min), constant_true.type.integer.minimum_value)
- self.assertEqual(str(r_max), constant_true.type.integer.maximum_value)
- self.assertEqual(str(r_mod), constant_false.type.integer.modulus)
- self.assertEqual(str(r_val), constant_false.type.integer.modular_value)
- self.assertEqual(str(r_min), constant_false.type.integer.minimum_value)
- self.assertEqual(str(r_max), constant_false.type.integer.maximum_value)
+ def test_choice_one_non_constant_integer(self):
+ cases = [
+ # t == 35 and f % 20 == 15 ==> res % 20 == 15
+ (35, 20, 15, 20, 15, -128 * 20 + 15, 127 * 20 + 15),
+ # t == 200035 and f % 20 == 15 ==> res % 20 == 15
+ (200035, 20, 15, 20, 15, -128 * 20 + 15, 200035),
+ # t == 21 and f % 20 == 16 ==> res % 5 == 1
+ (21, 20, 16, 5, 1, -128 * 20 + 16, 127 * 20 + 16),
+ ]
+ for t_val, f_mod, f_val, r_mod, r_val, r_min, r_max in cases:
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if (x == 0 ? {0} : y * {1} + {2}) == 0:\n"
+ " 1 [+1] UInt z\n"
+ " if (x == 0 ? y * {1} + {2} : {0}) == 0:\n"
+ " 1 [+1] UInt q\n".format(t_val, f_mod, f_val)
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field_constant_true = ir.module[0].type[0].structure.field[2]
+ constant_true = field_constant_true.existence_condition.function.args[0]
+ field_constant_false = ir.module[0].type[0].structure.field[3]
+ constant_false = field_constant_false.existence_condition.function.args[0]
+ self.assertEqual(str(r_mod), constant_true.type.integer.modulus)
+ self.assertEqual(str(r_val), constant_true.type.integer.modular_value)
+ self.assertEqual(str(r_min), constant_true.type.integer.minimum_value)
+ self.assertEqual(str(r_max), constant_true.type.integer.maximum_value)
+ self.assertEqual(str(r_mod), constant_false.type.integer.modulus)
+ self.assertEqual(str(r_val), constant_false.type.integer.modular_value)
+ self.assertEqual(str(r_min), constant_false.type.integer.minimum_value)
+ self.assertEqual(str(r_max), constant_false.type.integer.maximum_value)
- def test_choice_two_constant_integers(self):
- cases = [
- # t == 10 and f == 7 ==> res % 3 == 1
- (10, 7, 3, 1, 7, 10),
- # t == 4 and f == 4 ==> res == 4
- (4, 4, "infinity", 4, 4, 4),
- ]
- for (t_val, f_val, r_mod, r_val, r_min, r_max) in cases:
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if (x == 0 ? {} : {}) == 0:\n"
- " 1 [+1] UInt z\n".format(t_val, f_val))
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field_constant_true = ir.module[0].type[0].structure.field[2]
- constant_true = field_constant_true.existence_condition.function.args[0]
- self.assertEqual(str(r_mod), constant_true.type.integer.modulus)
- self.assertEqual(str(r_val), constant_true.type.integer.modular_value)
- self.assertEqual(str(r_min), constant_true.type.integer.minimum_value)
- self.assertEqual(str(r_max), constant_true.type.integer.maximum_value)
+ def test_choice_two_constant_integers(self):
+ cases = [
+ # t == 10 and f == 7 ==> res % 3 == 1
+ (10, 7, 3, 1, 7, 10),
+ # t == 4 and f == 4 ==> res == 4
+ (4, 4, "infinity", 4, 4, 4),
+ ]
+ for t_val, f_val, r_mod, r_val, r_min, r_max in cases:
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if (x == 0 ? {} : {}) == 0:\n"
+ " 1 [+1] UInt z\n".format(t_val, f_val)
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field_constant_true = ir.module[0].type[0].structure.field[2]
+ constant_true = field_constant_true.existence_condition.function.args[0]
+ self.assertEqual(str(r_mod), constant_true.type.integer.modulus)
+ self.assertEqual(str(r_val), constant_true.type.integer.modular_value)
+ self.assertEqual(str(r_min), constant_true.type.integer.minimum_value)
+ self.assertEqual(str(r_max), constant_true.type.integer.maximum_value)
- def test_constant_true_has(self):
- ir = self._make_ir("struct Foo:\n"
- " if $present(x):\n"
- " 1 [+1] UInt q\n"
- " 0 [+1] UInt x\n"
- " if x > 10:\n"
- " 1 [+1] Int y\n"
- " if false:\n"
- " 2 [+1] Int z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[0]
- has_func = field.existence_condition
- self.assertTrue(has_func.type.boolean.value)
+ def test_constant_true_has(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if $present(x):\n"
+ " 1 [+1] UInt q\n"
+ " 0 [+1] UInt x\n"
+ " if x > 10:\n"
+ " 1 [+1] Int y\n"
+ " if false:\n"
+ " 2 [+1] Int z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ has_func = field.existence_condition
+ self.assertTrue(has_func.type.boolean.value)
- def test_constant_false_has(self):
- ir = self._make_ir("struct Foo:\n"
- " if $present(z):\n"
- " 1 [+1] UInt q\n"
- " 0 [+1] UInt x\n"
- " if x > 10:\n"
- " 1 [+1] Int y\n"
- " if false:\n"
- " 2 [+1] Int z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[0]
- has_func = field.existence_condition
- self.assertTrue(has_func.type.boolean.HasField("value"))
- self.assertFalse(has_func.type.boolean.value)
+ def test_constant_false_has(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if $present(z):\n"
+ " 1 [+1] UInt q\n"
+ " 0 [+1] UInt x\n"
+ " if x > 10:\n"
+ " 1 [+1] Int y\n"
+ " if false:\n"
+ " 2 [+1] Int z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ has_func = field.existence_condition
+ self.assertTrue(has_func.type.boolean.HasField("value"))
+ self.assertFalse(has_func.type.boolean.value)
- def test_variable_has(self):
- ir = self._make_ir("struct Foo:\n"
- " if $present(y):\n"
- " 1 [+1] UInt q\n"
- " 0 [+1] UInt x\n"
- " if x > 10:\n"
- " 1 [+1] Int y\n"
- " if false:\n"
- " 2 [+1] Int z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[0]
- has_func = field.existence_condition
- self.assertFalse(has_func.type.boolean.HasField("value"))
+ def test_variable_has(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if $present(y):\n"
+ " 1 [+1] UInt q\n"
+ " 0 [+1] UInt x\n"
+ " if x > 10:\n"
+ " 1 [+1] Int y\n"
+ " if false:\n"
+ " 2 [+1] Int z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ has_func = field.existence_condition
+ self.assertFalse(has_func.type.boolean.HasField("value"))
- def test_max_of_constants(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if $max(0, 1, 2) == 0:\n"
- " 1 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[2]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("infinity", max_func.type.integer.modulus)
- self.assertEqual("2", max_func.type.integer.modular_value)
- self.assertEqual("2", max_func.type.integer.minimum_value)
- self.assertEqual("2", max_func.type.integer.maximum_value)
+ def test_max_of_constants(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if $max(0, 1, 2) == 0:\n"
+ " 1 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[2]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("infinity", max_func.type.integer.modulus)
+ self.assertEqual("2", max_func.type.integer.modular_value)
+ self.assertEqual("2", max_func.type.integer.minimum_value)
+ self.assertEqual("2", max_func.type.integer.maximum_value)
- def test_max_dominated_by_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if $max(x, y, 255) == 0:\n"
- " 1 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[2]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("infinity", max_func.type.integer.modulus)
- self.assertEqual("255", max_func.type.integer.modular_value)
- self.assertEqual("255", max_func.type.integer.minimum_value)
- self.assertEqual("255", max_func.type.integer.maximum_value)
+ def test_max_dominated_by_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if $max(x, y, 255) == 0:\n"
+ " 1 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[2]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("infinity", max_func.type.integer.modulus)
+ self.assertEqual("255", max_func.type.integer.modular_value)
+ self.assertEqual("255", max_func.type.integer.minimum_value)
+ self.assertEqual("255", max_func.type.integer.maximum_value)
- def test_max_of_variables(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if $max(x, y) == 0:\n"
- " 1 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[2]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("1", max_func.type.integer.modulus)
- self.assertEqual("0", max_func.type.integer.modular_value)
- self.assertEqual("0", max_func.type.integer.minimum_value)
- self.assertEqual("255", max_func.type.integer.maximum_value)
+ def test_max_of_variables(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if $max(x, y) == 0:\n"
+ " 1 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[2]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("1", max_func.type.integer.modulus)
+ self.assertEqual("0", max_func.type.integer.modular_value)
+ self.assertEqual("0", max_func.type.integer.minimum_value)
+ self.assertEqual("255", max_func.type.integer.maximum_value)
- def test_max_of_variables_with_shared_modulus(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " if $max(x * 8 + 5, y * 4 + 3) == 0:\n"
- " 1 [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[2]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("2", max_func.type.integer.modulus)
- self.assertEqual("1", max_func.type.integer.modular_value)
- self.assertEqual("5", max_func.type.integer.minimum_value)
- self.assertEqual("2045", max_func.type.integer.maximum_value)
+ def test_max_of_variables_with_shared_modulus(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " if $max(x * 8 + 5, y * 4 + 3) == 0:\n"
+ " 1 [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[2]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("2", max_func.type.integer.modulus)
+ self.assertEqual("1", max_func.type.integer.modular_value)
+ self.assertEqual("5", max_func.type.integer.minimum_value)
+ self.assertEqual("2045", max_func.type.integer.maximum_value)
- def test_max_of_three_variables(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " 2 [+2] Int z\n"
- " if $max(x, y, z) == 0:\n"
- " 1 [+1] UInt q\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[3]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("1", max_func.type.integer.modulus)
- self.assertEqual("0", max_func.type.integer.modular_value)
- self.assertEqual("0", max_func.type.integer.minimum_value)
- self.assertEqual("32767", max_func.type.integer.maximum_value)
+ def test_max_of_three_variables(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " 2 [+2] Int z\n"
+ " if $max(x, y, z) == 0:\n"
+ " 1 [+1] UInt q\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[3]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("1", max_func.type.integer.modulus)
+ self.assertEqual("0", max_func.type.integer.modular_value)
+ self.assertEqual("0", max_func.type.integer.minimum_value)
+ self.assertEqual("32767", max_func.type.integer.maximum_value)
- def test_max_of_one_variable(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " 2 [+2] Int z\n"
- " if $max(x * 2 + 3) == 0:\n"
- " 1 [+1] UInt q\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[3]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("2", max_func.type.integer.modulus)
- self.assertEqual("1", max_func.type.integer.modular_value)
- self.assertEqual("3", max_func.type.integer.minimum_value)
- self.assertEqual("513", max_func.type.integer.maximum_value)
+ def test_max_of_one_variable(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " 2 [+2] Int z\n"
+ " if $max(x * 2 + 3) == 0:\n"
+ " 1 [+1] UInt q\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[3]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("2", max_func.type.integer.modulus)
+ self.assertEqual("1", max_func.type.integer.modular_value)
+ self.assertEqual("3", max_func.type.integer.minimum_value)
+ self.assertEqual("513", max_func.type.integer.maximum_value)
- def test_max_of_one_variable_and_one_constant(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] Int y\n"
- " 2 [+2] Int z\n"
- " if $max(x * 2 + 3, 311) == 0:\n"
- " 1 [+1] UInt q\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field = ir.module[0].type[0].structure.field[3]
- max_func = field.existence_condition.function.args[0]
- self.assertEqual("2", max_func.type.integer.modulus)
- self.assertEqual("1", max_func.type.integer.modular_value)
- self.assertEqual("311", max_func.type.integer.minimum_value)
- self.assertEqual("513", max_func.type.integer.maximum_value)
+ def test_max_of_one_variable_and_one_constant(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] Int y\n"
+ " 2 [+2] Int z\n"
+ " if $max(x * 2 + 3, 311) == 0:\n"
+ " 1 [+1] UInt q\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field = ir.module[0].type[0].structure.field[3]
+ max_func = field.existence_condition.function.args[0]
+ self.assertEqual("2", max_func.type.integer.modulus)
+ self.assertEqual("1", max_func.type.integer.modular_value)
+ self.assertEqual("311", max_func.type.integer.minimum_value)
+ self.assertEqual("513", max_func.type.integer.maximum_value)
- def test_choice_non_integer_arguments(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " if x == 0 ? false : true:\n"
- " 1 [+1] UInt y\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- expr = ir.module[0].type[0].structure.field[1].existence_condition
- self.assertEqual("boolean", expr.type.WhichOneof("type"))
- self.assertFalse(expr.type.boolean.HasField("value"))
+ def test_choice_non_integer_arguments(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " if x == 0 ? false : true:\n"
+ " 1 [+1] UInt y\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ expr = ir.module[0].type[0].structure.field[1].existence_condition
+ self.assertEqual("boolean", expr.type.WhichOneof("type"))
+ self.assertFalse(expr.type.boolean.HasField("value"))
- def test_uint_value_range_for_explicit_size(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+x] UInt:16 y\n"
- " y [+1] UInt z\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("1", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual("0", z_start.type.integer.minimum_value)
- self.assertEqual("65535", z_start.type.integer.maximum_value)
+ def test_uint_value_range_for_explicit_size(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+x] UInt:16 y\n"
+ " y [+1] UInt z\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("1", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual("0", z_start.type.integer.minimum_value)
+ self.assertEqual("65535", z_start.type.integer.maximum_value)
- def test_uint_value_ranges(self):
- cases = [
- (1, 1),
- (2, 3),
- (3, 7),
- (4, 15),
- (8, 255),
- (12, 4095),
- (15, 32767),
- (16, 65535),
- (32, 4294967295),
- (48, 281474976710655),
- (64, 18446744073709551615),
- ]
- for bits, upper in cases:
- ir = self._make_ir("struct Foo:\n"
- " 0 [+8] bits:\n"
- " 0 [+{}] UInt x\n"
- " x [+1] UInt z\n".format(bits))
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("1", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual("0", z_start.type.integer.minimum_value)
- self.assertEqual(str(upper), z_start.type.integer.maximum_value)
+ def test_uint_value_ranges(self):
+ cases = [
+ (1, 1),
+ (2, 3),
+ (3, 7),
+ (4, 15),
+ (8, 255),
+ (12, 4095),
+ (15, 32767),
+ (16, 65535),
+ (32, 4294967295),
+ (48, 281474976710655),
+ (64, 18446744073709551615),
+ ]
+ for bits, upper in cases:
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+8] bits:\n"
+ " 0 [+{}] UInt x\n"
+ " x [+1] UInt z\n".format(bits)
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("1", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual("0", z_start.type.integer.minimum_value)
+ self.assertEqual(str(upper), z_start.type.integer.maximum_value)
- def test_int_value_ranges(self):
- cases = [
- (1, -1, 0),
- (2, -2, 1),
- (3, -4, 3),
- (4, -8, 7),
- (8, -128, 127),
- (12, -2048, 2047),
- (15, -16384, 16383),
- (16, -32768, 32767),
- (32, -2147483648, 2147483647),
- (48, -140737488355328, 140737488355327),
- (64, -9223372036854775808, 9223372036854775807),
- ]
- for bits, lower, upper in cases:
- ir = self._make_ir("struct Foo:\n"
- " 0 [+8] bits:\n"
- " 0 [+{}] Int x\n"
- " x [+1] UInt z\n".format(bits))
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("1", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual(str(lower), z_start.type.integer.minimum_value)
- self.assertEqual(str(upper), z_start.type.integer.maximum_value)
+ def test_int_value_ranges(self):
+ cases = [
+ (1, -1, 0),
+ (2, -2, 1),
+ (3, -4, 3),
+ (4, -8, 7),
+ (8, -128, 127),
+ (12, -2048, 2047),
+ (15, -16384, 16383),
+ (16, -32768, 32767),
+ (32, -2147483648, 2147483647),
+ (48, -140737488355328, 140737488355327),
+ (64, -9223372036854775808, 9223372036854775807),
+ ]
+ for bits, lower, upper in cases:
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+8] bits:\n"
+ " 0 [+{}] Int x\n"
+ " x [+1] UInt z\n".format(bits)
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("1", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual(str(lower), z_start.type.integer.minimum_value)
+ self.assertEqual(str(upper), z_start.type.integer.maximum_value)
- def test_bcd_value_ranges(self):
- cases = [
- (1, 1),
- (2, 3),
- (3, 7),
- (4, 9),
- (8, 99),
- (12, 999),
- (15, 7999),
- (16, 9999),
- (32, 99999999),
- (48, 999999999999),
- (64, 9999999999999999),
- ]
- for bits, upper in cases:
- ir = self._make_ir("struct Foo:\n"
- " 0 [+8] bits:\n"
- " 0 [+{}] Bcd x\n"
- " x [+1] UInt z\n".format(bits))
- self.assertEqual([], expression_bounds.compute_constants(ir))
- z_start = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual("1", z_start.type.integer.modulus)
- self.assertEqual("0", z_start.type.integer.modular_value)
- self.assertEqual("0", z_start.type.integer.minimum_value)
- self.assertEqual(str(upper), z_start.type.integer.maximum_value)
+ def test_bcd_value_ranges(self):
+ cases = [
+ (1, 1),
+ (2, 3),
+ (3, 7),
+ (4, 9),
+ (8, 99),
+ (12, 999),
+ (15, 7999),
+ (16, 9999),
+ (32, 99999999),
+ (48, 999999999999),
+ (64, 9999999999999999),
+ ]
+ for bits, upper in cases:
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+8] bits:\n"
+ " 0 [+{}] Bcd x\n"
+ " x [+1] UInt z\n".format(bits)
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ z_start = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual("1", z_start.type.integer.modulus)
+ self.assertEqual("0", z_start.type.integer.modular_value)
+ self.assertEqual("0", z_start.type.integer.minimum_value)
+ self.assertEqual(str(upper), z_start.type.integer.maximum_value)
- def test_virtual_field_bounds(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = x + 10\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field_y = ir.module[0].type[0].structure.field[1]
- self.assertEqual("1", field_y.read_transform.type.integer.modulus)
- self.assertEqual("0", field_y.read_transform.type.integer.modular_value)
- self.assertEqual("10", field_y.read_transform.type.integer.minimum_value)
- self.assertEqual("265", field_y.read_transform.type.integer.maximum_value)
+ def test_virtual_field_bounds(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " let y = x + 10\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field_y = ir.module[0].type[0].structure.field[1]
+ self.assertEqual("1", field_y.read_transform.type.integer.modulus)
+ self.assertEqual("0", field_y.read_transform.type.integer.modular_value)
+ self.assertEqual("10", field_y.read_transform.type.integer.minimum_value)
+ self.assertEqual("265", field_y.read_transform.type.integer.maximum_value)
- def test_virtual_field_bounds_copied(self):
- ir = self._make_ir("struct Foo:\n"
- " let z = y + 100\n"
- " let y = x + 10\n"
- " 0 [+1] UInt x\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field_z = ir.module[0].type[0].structure.field[0]
- self.assertEqual("1", field_z.read_transform.type.integer.modulus)
- self.assertEqual("0", field_z.read_transform.type.integer.modular_value)
- self.assertEqual("110", field_z.read_transform.type.integer.minimum_value)
- self.assertEqual("365", field_z.read_transform.type.integer.maximum_value)
- y_reference = field_z.read_transform.function.args[0]
- self.assertEqual("1", y_reference.type.integer.modulus)
- self.assertEqual("0", y_reference.type.integer.modular_value)
- self.assertEqual("10", y_reference.type.integer.minimum_value)
- self.assertEqual("265", y_reference.type.integer.maximum_value)
+ def test_virtual_field_bounds_copied(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " let z = y + 100\n"
+ " let y = x + 10\n"
+ " 0 [+1] UInt x\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field_z = ir.module[0].type[0].structure.field[0]
+ self.assertEqual("1", field_z.read_transform.type.integer.modulus)
+ self.assertEqual("0", field_z.read_transform.type.integer.modular_value)
+ self.assertEqual("110", field_z.read_transform.type.integer.minimum_value)
+ self.assertEqual("365", field_z.read_transform.type.integer.maximum_value)
+ y_reference = field_z.read_transform.function.args[0]
+ self.assertEqual("1", y_reference.type.integer.modulus)
+ self.assertEqual("0", y_reference.type.integer.modular_value)
+ self.assertEqual("10", y_reference.type.integer.minimum_value)
+ self.assertEqual("265", y_reference.type.integer.maximum_value)
- def test_constant_reference_to_virtual_bounds_copied(self):
- ir = self._make_ir("struct Foo:\n"
- " let ten = Bar.ten\n"
- " let truth = Bar.truth\n"
- "struct Bar:\n"
- " let ten = 10\n"
- " let truth = true\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field_ten = ir.module[0].type[0].structure.field[0]
- self.assertEqual("infinity", field_ten.read_transform.type.integer.modulus)
- self.assertEqual("10", field_ten.read_transform.type.integer.modular_value)
- self.assertEqual("10", field_ten.read_transform.type.integer.minimum_value)
- self.assertEqual("10", field_ten.read_transform.type.integer.maximum_value)
- field_truth = ir.module[0].type[0].structure.field[1]
- self.assertTrue(field_truth.read_transform.type.boolean.value)
+ def test_constant_reference_to_virtual_bounds_copied(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " let ten = Bar.ten\n"
+ " let truth = Bar.truth\n"
+ "struct Bar:\n"
+ " let ten = 10\n"
+ " let truth = true\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field_ten = ir.module[0].type[0].structure.field[0]
+ self.assertEqual("infinity", field_ten.read_transform.type.integer.modulus)
+ self.assertEqual("10", field_ten.read_transform.type.integer.modular_value)
+ self.assertEqual("10", field_ten.read_transform.type.integer.minimum_value)
+ self.assertEqual("10", field_ten.read_transform.type.integer.maximum_value)
+ field_truth = ir.module[0].type[0].structure.field[1]
+ self.assertTrue(field_truth.read_transform.type.boolean.value)
- def test_forward_reference_to_reference_to_enum_correctly_calculated(self):
- ir = self._make_ir("struct Foo:\n"
- " let ten = Bar.TEN\n"
- "enum Bar:\n"
- " TEN = TEN2\n"
- " TEN2 = 5 + 5\n")
- self.assertEqual([], expression_bounds.compute_constants(ir))
- field_ten = ir.module[0].type[0].structure.field[0]
- self.assertEqual("10", field_ten.read_transform.type.enumeration.value)
+ def test_forward_reference_to_reference_to_enum_correctly_calculated(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " let ten = Bar.TEN\n"
+ "enum Bar:\n"
+ " TEN = TEN2\n"
+ " TEN2 = 5 + 5\n"
+ )
+ self.assertEqual([], expression_bounds.compute_constants(ir))
+ field_ten = ir.module[0].type[0].structure.field[0]
+ self.assertEqual("10", field_ten.read_transform.type.enumeration.value)
class InfinityAugmentedArithmeticTest(unittest.TestCase):
- # TODO(bolms): Will there ever be any situations where all elements of the arg
- # to _min would be "infinity"?
- def test_min_of_infinities(self):
- self.assertEqual("infinity",
- expression_bounds._min(["infinity", "infinity"]))
+ # TODO(bolms): Will there ever be any situations where all elements of the arg
+ # to _min would be "infinity"?
+ def test_min_of_infinities(self):
+ self.assertEqual("infinity", expression_bounds._min(["infinity", "infinity"]))
- # TODO(bolms): Will there ever be any situations where all elements of the arg
- # to _max would be "-infinity"?
- def test_max_of_negative_infinities(self):
- self.assertEqual("-infinity",
- expression_bounds._max(["-infinity", "-infinity"]))
+ # TODO(bolms): Will there ever be any situations where all elements of the arg
+ # to _max would be "-infinity"?
+ def test_max_of_negative_infinities(self):
+ self.assertEqual(
+ "-infinity", expression_bounds._max(["-infinity", "-infinity"])
+ )
- def test_shared_modular_value_of_identical_modulus_and_value(self):
- self.assertEqual((10, 8),
- expression_bounds._shared_modular_value((10, 8), (10, 8)))
+ def test_shared_modular_value_of_identical_modulus_and_value(self):
+ self.assertEqual(
+ (10, 8), expression_bounds._shared_modular_value((10, 8), (10, 8))
+ )
- def test_shared_modular_value_of_identical_modulus(self):
- self.assertEqual((5, 3),
- expression_bounds._shared_modular_value((10, 8), (10, 3)))
+ def test_shared_modular_value_of_identical_modulus(self):
+ self.assertEqual(
+ (5, 3), expression_bounds._shared_modular_value((10, 8), (10, 3))
+ )
- def test_shared_modular_value_of_identical_value(self):
- self.assertEqual((6, 2),
- expression_bounds._shared_modular_value((18, 2), (12, 2)))
+ def test_shared_modular_value_of_identical_value(self):
+ self.assertEqual(
+ (6, 2), expression_bounds._shared_modular_value((18, 2), (12, 2))
+ )
- def test_shared_modular_value_of_different_arguments(self):
- self.assertEqual((7, 4),
- expression_bounds._shared_modular_value((21, 11), (14, 4)))
+ def test_shared_modular_value_of_different_arguments(self):
+ self.assertEqual(
+ (7, 4), expression_bounds._shared_modular_value((21, 11), (14, 4))
+ )
- def test_shared_modular_value_of_infinity_and_non(self):
- self.assertEqual((7, 4),
- expression_bounds._shared_modular_value(("infinity", 25),
- (14, 4)))
+ def test_shared_modular_value_of_infinity_and_non(self):
+ self.assertEqual(
+ (7, 4), expression_bounds._shared_modular_value(("infinity", 25), (14, 4))
+ )
- def test_shared_modular_value_of_infinity_and_infinity(self):
- self.assertEqual((14, 5),
- expression_bounds._shared_modular_value(("infinity", 19),
- ("infinity", 5)))
+ def test_shared_modular_value_of_infinity_and_infinity(self):
+ self.assertEqual(
+ (14, 5),
+ expression_bounds._shared_modular_value(("infinity", 19), ("infinity", 5)),
+ )
- def test_shared_modular_value_of_infinity_and_identical_value(self):
- self.assertEqual(("infinity", 5),
- expression_bounds._shared_modular_value(("infinity", 5),
- ("infinity", 5)))
+ def test_shared_modular_value_of_infinity_and_identical_value(self):
+ self.assertEqual(
+ ("infinity", 5),
+ expression_bounds._shared_modular_value(("infinity", 5), ("infinity", 5)),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/format.py b/compiler/front_end/format.py
index 1e2389b..1b7de31 100644
--- a/compiler/front_end/format.py
+++ b/compiler/front_end/format.py
@@ -31,99 +31,112 @@
def _parse_command_line(argv):
- """Parses the given command-line arguments."""
- argparser = argparse.ArgumentParser(description='Emboss compiler front end.',
- prog=argv[0])
- argparser.add_argument('input_file',
- type=str,
- nargs='+',
- help='.emb file to compile.')
- argparser.add_argument('--no-check-result',
- default=True,
- action='store_false',
- dest='check_result',
- help='Verify that the resulting formatted text '
- 'contains only whitespace changes.')
- argparser.add_argument('--debug-show-line-types',
- default=False,
- help='Show the computed type of each line.')
- argparser.add_argument('--no-edit-in-place',
- default=True,
- action='store_false',
- dest='edit_in_place',
- help='Write the formatted text back to the input '
- 'file.')
- argparser.add_argument('--indent',
- type=int,
- default=2,
- help='Number of spaces to use for each level of '
- 'indentation.')
- argparser.add_argument('--color-output',
- default='if-tty',
- choices=['always', 'never', 'if-tty', 'auto'],
- help="Print error messages using color. 'auto' is a "
- "synonym for 'if-tty'.")
- return argparser.parse_args(argv[1:])
+ """Parses the given command-line arguments."""
+ argparser = argparse.ArgumentParser(
+ description="Emboss compiler front end.", prog=argv[0]
+ )
+ argparser.add_argument(
+ "input_file", type=str, nargs="+", help=".emb file to compile."
+ )
+ argparser.add_argument(
+ "--no-check-result",
+ default=True,
+ action="store_false",
+ dest="check_result",
+ help="Verify that the resulting formatted text "
+ "contains only whitespace changes.",
+ )
+ argparser.add_argument(
+ "--debug-show-line-types",
+ default=False,
+ help="Show the computed type of each line.",
+ )
+ argparser.add_argument(
+ "--no-edit-in-place",
+ default=True,
+ action="store_false",
+ dest="edit_in_place",
+ help="Write the formatted text back to the input " "file.",
+ )
+ argparser.add_argument(
+ "--indent",
+ type=int,
+ default=2,
+ help="Number of spaces to use for each level of " "indentation.",
+ )
+ argparser.add_argument(
+ "--color-output",
+ default="if-tty",
+ choices=["always", "never", "if-tty", "auto"],
+ help="Print error messages using color. 'auto' is a " "synonym for 'if-tty'.",
+ )
+ return argparser.parse_args(argv[1:])
def _print_errors(errors, source_codes, flags):
- use_color = (flags.color_output == 'always' or
- (flags.color_output in ('auto', 'if-tty') and
- os.isatty(sys.stderr.fileno())))
- print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
+ use_color = flags.color_output == "always" or (
+ flags.color_output in ("auto", "if-tty") and os.isatty(sys.stderr.fileno())
+ )
+ print(error.format_errors(errors, source_codes, use_color), file=sys.stderr)
def main(argv=()):
- flags = _parse_command_line(argv)
+ flags = _parse_command_line(argv)
- if not flags.edit_in_place and len(flags.input_file) > 1:
- print('Multiple files may only be formatted without --no-edit-in-place.',
- file=sys.stderr)
- return 1
+ if not flags.edit_in_place and len(flags.input_file) > 1:
+ print(
+ "Multiple files may only be formatted without --no-edit-in-place.",
+ file=sys.stderr,
+ )
+ return 1
- if flags.edit_in_place and flags.debug_show_line_types:
- print('The flag --debug-show-line-types requires --no-edit-in-place.',
- file=sys.stderr)
- return 1
+ if flags.edit_in_place and flags.debug_show_line_types:
+ print(
+ "The flag --debug-show-line-types requires --no-edit-in-place.",
+ file=sys.stderr,
+ )
+ return 1
- for file_name in flags.input_file:
- with open(file_name) as f:
- source_code = f.read()
+ for file_name in flags.input_file:
+ with open(file_name) as f:
+ source_code = f.read()
- tokens, errors = tokenizer.tokenize(source_code, file_name)
- if errors:
- _print_errors(errors, {file_name: source_code}, flags)
- continue
+ tokens, errors = tokenizer.tokenize(source_code, file_name)
+ if errors:
+ _print_errors(errors, {file_name: source_code}, flags)
+ continue
- parse_result = parser.parse_module(tokens)
- if parse_result.error:
- _print_errors(
- [error.make_error_from_parse_error(file_name, parse_result.error)],
- {file_name: source_code},
- flags)
- continue
+ parse_result = parser.parse_module(tokens)
+ if parse_result.error:
+ _print_errors(
+ [error.make_error_from_parse_error(file_name, parse_result.error)],
+ {file_name: source_code},
+ flags,
+ )
+ continue
- formatted_text = format_emb.format_emboss_parse_tree(
- parse_result.parse_tree,
- format_emb.Config(show_line_types=flags.debug_show_line_types,
- indent_width=flags.indent))
+ formatted_text = format_emb.format_emboss_parse_tree(
+ parse_result.parse_tree,
+ format_emb.Config(
+ show_line_types=flags.debug_show_line_types, indent_width=flags.indent
+ ),
+ )
- if flags.check_result and not flags.debug_show_line_types:
- errors = format_emb.sanity_check_format_result(formatted_text,
- source_code)
- if errors:
- for e in errors:
- print(e, file=sys.stderr)
- continue
+ if flags.check_result and not flags.debug_show_line_types:
+ errors = format_emb.sanity_check_format_result(formatted_text, source_code)
+ if errors:
+ for e in errors:
+ print(e, file=sys.stderr)
+ continue
- if flags.edit_in_place:
- with open(file_name, 'w') as f:
- f.write(formatted_text)
- else:
- sys.stdout.write(formatted_text)
+ if flags.edit_in_place:
+ with open(file_name, "w") as f:
+ f.write(formatted_text)
+ else:
+ sys.stdout.write(formatted_text)
- return 0
+ return 0
-if __name__ == '__main__':
- sys.exit(main(sys.argv))
+if __name__ == "__main__":
+ sys.exit(main(sys.argv))
diff --git a/compiler/front_end/format_emb.py b/compiler/front_end/format_emb.py
index e4b5c23..df9bcf3 100644
--- a/compiler/front_end/format_emb.py
+++ b/compiler/front_end/format_emb.py
@@ -28,27 +28,26 @@
from compiler.util import parser_types
-class Config(collections.namedtuple('Config',
- ['indent_width', 'show_line_types'])):
- """Configuration for formatting."""
+class Config(collections.namedtuple("Config", ["indent_width", "show_line_types"])):
+ """Configuration for formatting."""
- def __new__(cls, indent_width=2, show_line_types=False):
- return super(cls, Config).__new__(cls, indent_width, show_line_types)
+ def __new__(cls, indent_width=2, show_line_types=False):
+ return super(cls, Config).__new__(cls, indent_width, show_line_types)
-class _Row(collections.namedtuple('Row', ['name', 'columns', 'indent'])):
- """Structured contents of a single line."""
+class _Row(collections.namedtuple("Row", ["name", "columns", "indent"])):
+ """Structured contents of a single line."""
- def __new__(cls, name, columns=None, indent=0):
- return super(cls, _Row).__new__(cls, name, tuple(columns or []), indent)
+ def __new__(cls, name, columns=None, indent=0):
+ return super(cls, _Row).__new__(cls, name, tuple(columns or []), indent)
-class _Block(collections.namedtuple('Block', ['prefix', 'header', 'body'])):
- """Structured block of multiple lines."""
+class _Block(collections.namedtuple("Block", ["prefix", "header", "body"])):
+ """Structured block of multiple lines."""
- def __new__(cls, prefix, header, body):
- assert header
- return super(cls, _Block).__new__(cls, prefix, header, body)
+ def __new__(cls, prefix, header, body):
+ assert header
+ return super(cls, _Block).__new__(cls, prefix, header, body)
# Map of productions to their formatters.
@@ -56,280 +55,295 @@
def format_emboss_parse_tree(parse_tree, config, used_productions=None):
- """Formats Emboss source code.
+ """Formats Emboss source code.
- Arguments:
- parse_tree: A parse tree of an Emboss source file.
- config: A Config tuple with formatting options.
- used_productions: An optional set to which all used productions will be
- added. Intended for use by test code to ensure full production
- coverage.
+ Arguments:
+ parse_tree: A parse tree of an Emboss source file.
+ config: A Config tuple with formatting options.
+ used_productions: An optional set to which all used productions will be
+ added. Intended for use by test code to ensure full production
+ coverage.
- Returns:
- A string of the reformatted source text.
- """
- if hasattr(parse_tree, 'children'):
- parsed_children = [format_emboss_parse_tree(child, config, used_productions)
- for child in parse_tree.children]
- args = parsed_children + [config]
- if used_productions is not None:
- used_productions.add(parse_tree.production)
- return _formatters[parse_tree.production](*args)
- else:
- assert isinstance(parse_tree, parser_types.Token), str(parse_tree)
- return parse_tree.text
+ Returns:
+ A string of the reformatted source text.
+ """
+ if hasattr(parse_tree, "children"):
+ parsed_children = [
+ format_emboss_parse_tree(child, config, used_productions)
+ for child in parse_tree.children
+ ]
+ args = parsed_children + [config]
+ if used_productions is not None:
+ used_productions.add(parse_tree.production)
+ return _formatters[parse_tree.production](*args)
+ else:
+ assert isinstance(parse_tree, parser_types.Token), str(parse_tree)
+ return parse_tree.text
def sanity_check_format_result(formatted_text, original_text):
- """Checks that the given texts are equivalent."""
- # The texts are considered equivalent if they tokenize to the same token
- # stream, except that:
- #
- # Multiple consecutive newline tokens are equivalent to a single newline
- # token.
- #
- # Extra newline tokens at the start of the stream should be ignored.
- #
- # Whitespace at the start or end of a token should be ignored. This matters
- # for documentation and comment tokens, which may have had trailing whitespace
- # in the original text, and for indent tokens, which may contain a different
- # number of space and/or tab characters.
- original_tokens, errors = tokenizer.tokenize(original_text, '')
- if errors:
- return ['BUG: original text is not tokenizable: {!r}'.format(errors)]
+ """Checks that the given texts are equivalent."""
+ # The texts are considered equivalent if they tokenize to the same token
+ # stream, except that:
+ #
+ # Multiple consecutive newline tokens are equivalent to a single newline
+ # token.
+ #
+ # Extra newline tokens at the start of the stream should be ignored.
+ #
+ # Whitespace at the start or end of a token should be ignored. This matters
+ # for documentation and comment tokens, which may have had trailing whitespace
+ # in the original text, and for indent tokens, which may contain a different
+ # number of space and/or tab characters.
+ original_tokens, errors = tokenizer.tokenize(original_text, "")
+ if errors:
+ return ["BUG: original text is not tokenizable: {!r}".format(errors)]
- formatted_tokens, errors = tokenizer.tokenize(formatted_text, '')
- if errors:
- return ['BUG: formatted text is not tokenizable: {!r}'.format(errors)]
+ formatted_tokens, errors = tokenizer.tokenize(formatted_text, "")
+ if errors:
+ return ["BUG: formatted text is not tokenizable: {!r}".format(errors)]
- o_tokens = _collapse_newline_tokens(original_tokens)
- f_tokens = _collapse_newline_tokens(formatted_tokens)
- for i in range(len(o_tokens)):
- if (o_tokens[i].symbol != f_tokens[i].symbol or
- o_tokens[i].text.strip() != f_tokens[i].text.strip()):
- return ['BUG: Symbol {} differs: {!r} vs {!r}'.format(i, o_tokens[i],
- f_tokens[i])]
- return []
+ o_tokens = _collapse_newline_tokens(original_tokens)
+ f_tokens = _collapse_newline_tokens(formatted_tokens)
+ for i in range(len(o_tokens)):
+ if (
+ o_tokens[i].symbol != f_tokens[i].symbol
+ or o_tokens[i].text.strip() != f_tokens[i].text.strip()
+ ):
+ return [
+ "BUG: Symbol {} differs: {!r} vs {!r}".format(
+ i, o_tokens[i], f_tokens[i]
+ )
+ ]
+ return []
def _collapse_newline_tokens(token_list):
- r"""Collapses multiple consecutive "\\n" tokens into a single newline."""
- result = []
- for symbol, group in itertools.groupby(token_list, lambda x: x.symbol):
- if symbol == '"\\n"':
- # Skip all newlines if they are at the start, otherwise add a single
- # newline for each consecutive run of newlines.
- if result:
- result.append(list(group)[0])
- else:
- result.extend(group)
- return result
+ r"""Collapses multiple consecutive "\\n" tokens into a single newline."""
+ result = []
+ for symbol, group in itertools.groupby(token_list, lambda x: x.symbol):
+ if symbol == '"\\n"':
+ # Skip all newlines if they are at the start, otherwise add a single
+ # newline for each consecutive run of newlines.
+ if result:
+ result.append(list(group)[0])
+ else:
+ result.extend(group)
+ return result
def _indent_row(row):
- """Adds one level of indent to the given row, returning a new row."""
- assert isinstance(row, _Row), repr(row)
- return _Row(name=row.name,
- columns=row.columns,
- indent=row.indent + 1)
+ """Adds one level of indent to the given row, returning a new row."""
+ assert isinstance(row, _Row), repr(row)
+ return _Row(name=row.name, columns=row.columns, indent=row.indent + 1)
def _indent_rows(rows):
- """Adds one level of indent to the given rows, returning a new list."""
- return list(map(_indent_row, rows))
+ """Adds one level of indent to the given rows, returning a new list."""
+ return list(map(_indent_row, rows))
def _indent_blocks(blocks):
- """Adds one level of indent to the given blocks, returning a new list."""
- return [_Block(prefix=_indent_rows(block.prefix),
- header=_indent_row(block.header),
- body=_indent_rows(block.body))
- for block in blocks]
+ """Adds one level of indent to the given blocks, returning a new list."""
+ return [
+ _Block(
+ prefix=_indent_rows(block.prefix),
+ header=_indent_row(block.header),
+ body=_indent_rows(block.body),
+ )
+ for block in blocks
+ ]
def _intersperse(interspersed, sections):
- """Intersperses `interspersed` between non-empty `sections`."""
- result = []
- for section in sections:
- if section:
- if result:
- result.extend(interspersed)
- result.extend(section)
- return result
+ """Intersperses `interspersed` between non-empty `sections`."""
+ result = []
+ for section in sections:
+ if section:
+ if result:
+ result.extend(interspersed)
+ result.extend(section)
+ return result
def _should_add_blank_lines(blocks):
- """Returns true if blank lines should be added between blocks."""
- other_non_empty_lines = 0
- last_non_empty_lines = 0
- for block in blocks:
- last_non_empty_lines = len([line for line in
- block.body + block.prefix
- if line.columns])
- other_non_empty_lines += last_non_empty_lines
- # Vertical spaces should be added if there are more interior
- # non-empty-non-header lines than header lines.
- return len(blocks) <= other_non_empty_lines - last_non_empty_lines
+ """Returns true if blank lines should be added between blocks."""
+ other_non_empty_lines = 0
+ last_non_empty_lines = 0
+ for block in blocks:
+ last_non_empty_lines = len(
+ [line for line in block.body + block.prefix if line.columns]
+ )
+ other_non_empty_lines += last_non_empty_lines
+ # Vertical spaces should be added if there are more interior
+ # non-empty-non-header lines than header lines.
+ return len(blocks) <= other_non_empty_lines - last_non_empty_lines
def _columnize(blocks, indent_width, indent_columns=1):
- """Aligns columns in the header rows of the given blocks.
+ """Aligns columns in the header rows of the given blocks.
- The `indent_columns` argument is used to determine how many columns should be
- indented. With `indent_columns == 1`, the result would be:
+ The `indent_columns` argument is used to determine how many columns should be
+ indented. With `indent_columns == 1`, the result would be:
- AA BB CC
- AAA BBB CCC
- A B C
+ AA BB CC
+ AAA BBB CCC
+ A B C
- With `indent_columns == 2`:
+ With `indent_columns == 2`:
- AA BB CC
- AAA BBB CCC
- A B C
+ AA BB CC
+ AAA BBB CCC
+ A B C
- With `indent_columns == 1`, only the first column is indented compared to
- surrounding rows; with `indent_columns == 2`, both the first and second
- columns are indented.
+ With `indent_columns == 1`, only the first column is indented compared to
+ surrounding rows; with `indent_columns == 2`, both the first and second
+ columns are indented.
- Arguments:
- blocks: A list of _Blocks to columnize.
- indent_width: The number of spaces per level of indent.
- indent_columns: The number of columns to indent.
+ Arguments:
+ blocks: A list of _Blocks to columnize.
+ indent_width: The number of spaces per level of indent.
+ indent_columns: The number of columns to indent.
- Returns:
- A list of _Rows of the prefix, header, and body _Rows of each block, where
- the header _Rows of each type have had their columns aligned.
- """
- single_width_separators = {'enum-value': {0, 1}, 'field': {0}}
- # For each type of row, figure out how many characters each column needs.
- row_types = collections.defaultdict(
- lambda: collections.defaultdict(lambda: 0))
- for block in blocks:
- max_lengths = row_types[block.header.name]
- for i in range(len(block.header.columns)):
- if i == indent_columns - 1:
- adjustment = block.header.indent * indent_width
- else:
- adjustment = 0
- max_lengths[i] = max(max_lengths[i],
- len(block.header.columns[i]) + adjustment)
+ Returns:
+ A list of _Rows of the prefix, header, and body _Rows of each block, where
+ the header _Rows of each type have had their columns aligned.
+ """
+ single_width_separators = {"enum-value": {0, 1}, "field": {0}}
+ # For each type of row, figure out how many characters each column needs.
+ row_types = collections.defaultdict(lambda: collections.defaultdict(lambda: 0))
+ for block in blocks:
+ max_lengths = row_types[block.header.name]
+ for i in range(len(block.header.columns)):
+ if i == indent_columns - 1:
+ adjustment = block.header.indent * indent_width
+ else:
+ adjustment = 0
+ max_lengths[i] = max(
+ max_lengths[i], len(block.header.columns[i]) + adjustment
+ )
- assert len(row_types) < 3
+ assert len(row_types) < 3
- # Then, for each row, actually columnize it.
- result = []
- for block in blocks:
- columns = []
- for i in range(len(block.header.columns)):
- column_width = row_types[block.header.name][i]
- if column_width == 0:
- # Zero-width columns are entirely omitted, including their column
- # separators.
- pass
- else:
- if i == indent_columns - 1:
- # This function only performs the right padding for each column.
- # Since the left padding for indent will be added later, the
- # corresponding space needs to be removed from the right padding of
- # the first column.
- column_width -= block.header.indent * indent_width
- if i in single_width_separators.get(block.header.name, []):
- # Only one space around the "=" in enum values and between the start
- # and size in field locations.
- column_width += 1
- else:
- column_width += 2
- columns.append(block.header.columns[i].ljust(column_width))
- result.append(block.prefix + [_Row(block.header.name,
- [''.join(columns).rstrip()],
- block.header.indent)] + block.body)
- return result
+ # Then, for each row, actually columnize it.
+ result = []
+ for block in blocks:
+ columns = []
+ for i in range(len(block.header.columns)):
+ column_width = row_types[block.header.name][i]
+ if column_width == 0:
+ # Zero-width columns are entirely omitted, including their column
+ # separators.
+ pass
+ else:
+ if i == indent_columns - 1:
+ # This function only performs the right padding for each column.
+ # Since the left padding for indent will be added later, the
+ # corresponding space needs to be removed from the right padding of
+ # the first column.
+ column_width -= block.header.indent * indent_width
+ if i in single_width_separators.get(block.header.name, []):
+ # Only one space around the "=" in enum values and between the start
+ # and size in field locations.
+ column_width += 1
+ else:
+ column_width += 2
+ columns.append(block.header.columns[i].ljust(column_width))
+ result.append(
+ block.prefix
+ + [
+ _Row(
+ block.header.name, ["".join(columns).rstrip()], block.header.indent
+ )
+ ]
+ + block.body
+ )
+ return result
def _indent_blanks_and_comments(rows):
- """Indents blank and comment lines to match the next non-blank line."""
- result = []
- previous_indent = 0
- for row in reversed(rows):
- if not ''.join(row.columns) or row.name == 'comment':
- result.append(_Row(row.name, row.columns, previous_indent))
- else:
- result.append(row)
- previous_indent = row.indent
- return reversed(result)
+ """Indents blank and comment lines to match the next non-blank line."""
+ result = []
+ previous_indent = 0
+ for row in reversed(rows):
+ if not "".join(row.columns) or row.name == "comment":
+ result.append(_Row(row.name, row.columns, previous_indent))
+ else:
+ result.append(row)
+ previous_indent = row.indent
+ return reversed(result)
def _add_blank_rows_on_dedent(rows):
- """Adds blank rows before dedented lines, where needed."""
- result = []
- previous_indent = 0
- previous_row_was_blank = True
- for row in rows:
- row_is_blank = not ''.join(row.columns)
- found_dedent = previous_indent > row.indent
- if found_dedent and not previous_row_was_blank and not row_is_blank:
- result.append(_Row('dedent-space', [], row.indent))
- result.append(row)
- previous_indent = row.indent
- previous_row_was_blank = row_is_blank
- return result
+ """Adds blank rows before dedented lines, where needed."""
+ result = []
+ previous_indent = 0
+ previous_row_was_blank = True
+ for row in rows:
+ row_is_blank = not "".join(row.columns)
+ found_dedent = previous_indent > row.indent
+ if found_dedent and not previous_row_was_blank and not row_is_blank:
+ result.append(_Row("dedent-space", [], row.indent))
+ result.append(row)
+ previous_indent = row.indent
+ previous_row_was_blank = row_is_blank
+ return result
def _render_row_to_text(row, indent_width):
- assert len(row.columns) < 2, '{!r}'.format(row)
- text = ' ' * indent_width * row.indent
- text += ''.join(row.columns)
- return text.rstrip()
+ assert len(row.columns) < 2, "{!r}".format(row)
+ text = " " * indent_width * row.indent
+ text += "".join(row.columns)
+ return text.rstrip()
def _render_rows_to_text(rows, indent_width, show_line_types):
- max_row_name_len = max([0] + [len(row.name) for row in rows])
- flattened_rows = []
- for row in rows:
- row_text = _render_row_to_text(row, indent_width)
- if show_line_types:
- row_text = row.name.ljust(max_row_name_len) + '|' + row_text
- flattened_rows.append(row_text)
- return '\n'.join(flattened_rows + [''])
+ max_row_name_len = max([0] + [len(row.name) for row in rows])
+ flattened_rows = []
+ for row in rows:
+ row_text = _render_row_to_text(row, indent_width)
+ if show_line_types:
+ row_text = row.name.ljust(max_row_name_len) + "|" + row_text
+ flattened_rows.append(row_text)
+ return "\n".join(flattened_rows + [""])
def _check_productions():
- """Asserts that the productions in this module match those in module_ir."""
- productions_ok = True
- for production in module_ir.PRODUCTIONS:
- if production not in _formatters:
- productions_ok = False
- print('@_formats({!r})'.format(str(production)))
+ """Asserts that the productions in this module match those in module_ir."""
+ productions_ok = True
+ for production in module_ir.PRODUCTIONS:
+ if production not in _formatters:
+ productions_ok = False
+ print("@_formats({!r})".format(str(production)))
- for production in _formatters:
- if production not in module_ir.PRODUCTIONS:
- productions_ok = False
- print('not @_formats({!r})'.format(str(production)))
+ for production in _formatters:
+ if production not in module_ir.PRODUCTIONS:
+ productions_ok = False
+ print("not @_formats({!r})".format(str(production)))
- assert productions_ok, 'Grammar mismatch.'
+ assert productions_ok, "Grammar mismatch."
def _formats_with_config(production_text):
- """Marks a function as a formatter requiring a config argument."""
- production = parser_types.Production.parse(production_text)
+ """Marks a function as a formatter requiring a config argument."""
+ production = parser_types.Production.parse(production_text)
- def formats(f):
- assert production not in _formatters, production
- _formatters[production] = f
- return f
+ def formats(f):
+ assert production not in _formatters, production
+ _formatters[production] = f
+ return f
- return formats
+ return formats
def _formats(production_text):
- """Marks a function as the formatter for a particular production."""
+ """Marks a function as the formatter for a particular production."""
- def strip_config_argument(f):
- _formats_with_config(production_text)(lambda *a, **kw: f(*a[:-1], **kw))
- return f
+ def strip_config_argument(f):
+ _formats_with_config(production_text)(lambda *a, **kw: f(*a[:-1], **kw))
+ return f
- return strip_config_argument
+ return strip_config_argument
################################################################################
@@ -358,388 +372,490 @@
# strings.
-@_formats_with_config('module -> comment-line* doc-line* import-line*'
- ' attribute-line* type-definition*')
+@_formats_with_config(
+ "module -> comment-line* doc-line* import-line*"
+ " attribute-line* type-definition*"
+)
def _module(comments, docs, imports, attributes, types, config):
- """Performs top-level formatting for an Emboss source file."""
+ """Performs top-level formatting for an Emboss source file."""
- # The top-level sections other than types should be separated by single lines.
- header_rows = _intersperse(
- [_Row('section-break')],
- [_strip_empty_leading_trailing_comment_lines(comments), docs, imports,
- attributes])
+ # The top-level sections other than types should be separated by single lines.
+ header_rows = _intersperse(
+ [_Row("section-break")],
+ [
+ _strip_empty_leading_trailing_comment_lines(comments),
+ docs,
+ imports,
+ attributes,
+ ],
+ )
- # Top-level types should be separated by double lines from themselves and from
- # the header rows.
- rows = _intersperse(
- [_Row('top-type-separator'), _Row('top-type-separator')],
- [header_rows] + types)
+ # Top-level types should be separated by double lines from themselves and from
+ # the header rows.
+ rows = _intersperse(
+ [_Row("top-type-separator"), _Row("top-type-separator")], [header_rows] + types
+ )
- # Final fixups.
- rows = _indent_blanks_and_comments(rows)
- rows = _add_blank_rows_on_dedent(rows)
- return _render_rows_to_text(rows, config.indent_width, config.show_line_types)
+ # Final fixups.
+ rows = _indent_blanks_and_comments(rows)
+ rows = _add_blank_rows_on_dedent(rows)
+ return _render_rows_to_text(rows, config.indent_width, config.show_line_types)
-@_formats('doc-line -> doc Comment? eol')
+@_formats("doc-line -> doc Comment? eol")
def _doc_line(doc, comment, eol):
- assert not comment, 'Comment should not be possible on the same line as doc.'
- return [_Row('doc', [doc])] + eol
+ assert not comment, "Comment should not be possible on the same line as doc."
+ return [_Row("doc", [doc])] + eol
-@_formats('import-line -> "import" string-constant "as" snake-word Comment?'
- ' eol')
+@_formats(
+ 'import-line -> "import" string-constant "as" snake-word Comment?'
+ " eol"
+)
def _import_line(import_, filename, as_, name, comment, eol):
- return [_Row('import', ['{} {} {} {} {}'.format(
- import_, filename, as_, name, comment)])] + eol
+ return [
+ _Row(
+ "import", ["{} {} {} {} {}".format(import_, filename, as_, name, comment)]
+ )
+ ] + eol
-@_formats('attribute-line -> attribute Comment? eol')
+@_formats("attribute-line -> attribute Comment? eol")
def _attribute_line(attribute, comment, eol):
- return [_Row('attribute', ['{} {}'.format(attribute, comment)])] + eol
+ return [_Row("attribute", ["{} {}".format(attribute, comment)])] + eol
-@_formats('attribute -> "[" attribute-context? "$default"? snake-word ":"'
- ' attribute-value "]"')
+@_formats(
+ 'attribute -> "[" attribute-context? "$default"? snake-word ":"'
+ ' attribute-value "]"'
+)
def _attribute(open_, context, default, name, colon, value, close):
- return ''.join([open_,
- _concatenate_with_spaces(context, default, name + colon,
- value),
- close])
+ return "".join(
+ [open_, _concatenate_with_spaces(context, default, name + colon, value), close]
+ )
@_formats('parameter-definition -> snake-name ":" type')
def _parameter_definition(name, colon, type_specifier):
- return '{}{} {}'.format(name, colon, type_specifier)
+ return "{}{} {}".format(name, colon, type_specifier)
-@_formats('type-definition* -> type-definition type-definition*')
+@_formats("type-definition* -> type-definition type-definition*")
def _type_defitinions(definition, definitions):
- return [definition] + definitions
+ return [definition] + definitions
-@_formats('bits -> "bits" type-name delimited-parameter-definition-list? ":"'
- ' Comment? eol bits-body')
-@_formats('struct -> "struct" type-name delimited-parameter-definition-list?'
- ' ":" Comment? eol struct-body')
+@_formats(
+ 'bits -> "bits" type-name delimited-parameter-definition-list? ":"'
+ " Comment? eol bits-body"
+)
+@_formats(
+ 'struct -> "struct" type-name delimited-parameter-definition-list?'
+ ' ":" Comment? eol struct-body'
+)
def _structure_type(struct, name, parameters, colon, comment, eol, body):
- return ([_Row('type-header',
- ['{} {}{}{} {}'.format(
- struct, name, parameters, colon, comment)])] +
- eol + body)
+ return (
+ [
+ _Row(
+ "type-header",
+ ["{} {}{}{} {}".format(struct, name, parameters, colon, comment)],
+ )
+ ]
+ + eol
+ + body
+ )
@_formats('enum -> "enum" type-name ":" Comment? eol enum-body')
@_formats('external -> "external" type-name ":" Comment? eol external-body')
def _type(struct, name, colon, comment, eol, body):
- return ([_Row('type-header',
- ['{} {}{} {}'.format(struct, name, colon, comment)])] +
- eol + body)
+ return (
+ [_Row("type-header", ["{} {}{} {}".format(struct, name, colon, comment)])]
+ + eol
+ + body
+ )
-@_formats_with_config('bits-body -> Indent doc-line* attribute-line*'
- ' type-definition* bits-field-block Dedent')
@_formats_with_config(
- 'struct-body -> Indent doc-line* attribute-line*'
- ' type-definition* struct-field-block Dedent')
-def _structure_body(indent, docs, attributes, type_definitions, fields, dedent,
- config):
- del indent, dedent # Unused.
- spacing = [_Row('field-separator')] if _should_add_blank_lines(fields) else []
- columnized_fields = _columnize(fields, config.indent_width, indent_columns=2)
- return _indent_rows(_intersperse(
- spacing, [docs, attributes] + type_definitions + columnized_fields))
+ "bits-body -> Indent doc-line* attribute-line*"
+ " type-definition* bits-field-block Dedent"
+)
+@_formats_with_config(
+ "struct-body -> Indent doc-line* attribute-line*"
+ " type-definition* struct-field-block Dedent"
+)
+def _structure_body(indent, docs, attributes, type_definitions, fields, dedent, config):
+ del indent, dedent # Unused.
+ spacing = [_Row("field-separator")] if _should_add_blank_lines(fields) else []
+ columnized_fields = _columnize(fields, config.indent_width, indent_columns=2)
+ return _indent_rows(
+ _intersperse(spacing, [docs, attributes] + type_definitions + columnized_fields)
+ )
@_formats('field-location -> expression "[" "+" expression "]"')
def _field_location(start, open_bracket, plus, size, close_bracket):
- return [start, open_bracket + plus + size + close_bracket]
+ return [start, open_bracket + plus + size + close_bracket]
-@_formats('anonymous-bits-field-block -> conditional-anonymous-bits-field-block'
- ' anonymous-bits-field-block')
-@_formats('anonymous-bits-field-block -> unconditional-anonymous-bits-field'
- ' anonymous-bits-field-block')
-@_formats('bits-field-block -> conditional-bits-field-block bits-field-block')
-@_formats('bits-field-block -> unconditional-bits-field bits-field-block')
-@_formats('struct-field-block -> conditional-struct-field-block'
- ' struct-field-block')
-@_formats('struct-field-block -> unconditional-struct-field struct-field-block')
-@_formats('unconditional-anonymous-bits-field* ->'
- ' unconditional-anonymous-bits-field'
- ' unconditional-anonymous-bits-field*')
-@_formats('unconditional-anonymous-bits-field+ ->'
- ' unconditional-anonymous-bits-field'
- ' unconditional-anonymous-bits-field*')
-@_formats('unconditional-bits-field* -> unconditional-bits-field'
- ' unconditional-bits-field*')
-@_formats('unconditional-bits-field+ -> unconditional-bits-field'
- ' unconditional-bits-field*')
-@_formats('unconditional-struct-field* -> unconditional-struct-field'
- ' unconditional-struct-field*')
-@_formats('unconditional-struct-field+ -> unconditional-struct-field'
- ' unconditional-struct-field*')
-def _structure_block(field, block):
- """Prepends field to block."""
- return field + block
-
-
-@_formats('virtual-field -> "let" snake-name "=" expression Comment? eol'
- ' field-body?')
-def _virtual_field(let_keyword, name, equals, value, comment, eol, body):
- # This formatting doesn't look the best when there are blocks of several
- # virtual fields next to each other, but works pretty well when they're
- # intermixed with physical fields. It's probably good enough for now, since
- # there aren't (yet) any virtual fields in real .embs, and will probably only
- # be a few in the near future.
- return [_Block([],
- _Row('virtual-field',
- [_concatenate_with(
- ' ',
- _concatenate_with_spaces(let_keyword, name, equals,
- value),
- comment)]),
- eol + body)]
-
-
-@_formats('field -> field-location type snake-name abbreviation?'
- ' attribute* doc? Comment? eol field-body?')
-def _unconditional_field(location, type_, name, abbreviation, attributes, doc,
- comment, eol, body):
- return [_Block([],
- _Row('field',
- location + [type_,
- _concatenate_with_spaces(name, abbreviation),
- attributes, doc, comment]),
- eol + body)]
-
-
-@_formats('field-body -> Indent doc-line* attribute-line* Dedent')
-def _field_body(indent, docs, attributes, dedent):
- del indent, dedent # Unused
- return _indent_rows(docs + attributes)
-
-
-@_formats('anonymous-bits-field-definition ->'
- ' field-location "bits" ":" Comment? eol anonymous-bits-body')
-def _inline_bits(location, bits, colon, comment, eol, body):
- # Even though an anonymous bits field technically defines a new, anonymous
- # type, conceptually it's more like defining a bunch of fields on the
- # surrounding type, so it is treated as an inline list of blocks, instead of
- # being separately formatted.
- header_row = _Row('field', [location[0], location[1] + ' ' + bits + colon,
- '', '', '', '', comment])
- return ([_Block([], header_row, eol + body.header_lines)] +
- body.field_blocks)
-
-
-@_formats('inline-enum-field-definition ->'
- ' field-location "enum" snake-name abbreviation? ":" Comment? eol'
- ' enum-body')
@_formats(
- 'inline-struct-field-definition ->'
+ "anonymous-bits-field-block -> conditional-anonymous-bits-field-block"
+ " anonymous-bits-field-block"
+)
+@_formats(
+ "anonymous-bits-field-block -> unconditional-anonymous-bits-field"
+ " anonymous-bits-field-block"
+)
+@_formats("bits-field-block -> conditional-bits-field-block bits-field-block")
+@_formats("bits-field-block -> unconditional-bits-field bits-field-block")
+@_formats(
+ "struct-field-block -> conditional-struct-field-block"
+ " struct-field-block"
+)
+@_formats("struct-field-block -> unconditional-struct-field struct-field-block")
+@_formats(
+ "unconditional-anonymous-bits-field* ->"
+ " unconditional-anonymous-bits-field"
+ " unconditional-anonymous-bits-field*"
+)
+@_formats(
+ "unconditional-anonymous-bits-field+ ->"
+ " unconditional-anonymous-bits-field"
+ " unconditional-anonymous-bits-field*"
+)
+@_formats(
+ "unconditional-bits-field* -> unconditional-bits-field"
+ " unconditional-bits-field*"
+)
+@_formats(
+ "unconditional-bits-field+ -> unconditional-bits-field"
+ " unconditional-bits-field*"
+)
+@_formats(
+ "unconditional-struct-field* -> unconditional-struct-field"
+ " unconditional-struct-field*"
+)
+@_formats(
+ "unconditional-struct-field+ -> unconditional-struct-field"
+ " unconditional-struct-field*"
+)
+def _structure_block(field, block):
+ """Prepends field to block."""
+ return field + block
+
+
+@_formats(
+ 'virtual-field -> "let" snake-name "=" expression Comment? eol'
+ " field-body?"
+)
+def _virtual_field(let_keyword, name, equals, value, comment, eol, body):
+ # This formatting doesn't look the best when there are blocks of several
+ # virtual fields next to each other, but works pretty well when they're
+ # intermixed with physical fields. It's probably good enough for now, since
+ # there aren't (yet) any virtual fields in real .embs, and will probably only
+ # be a few in the near future.
+ return [
+ _Block(
+ [],
+ _Row(
+ "virtual-field",
+ [
+ _concatenate_with(
+ " ",
+ _concatenate_with_spaces(let_keyword, name, equals, value),
+ comment,
+ )
+ ],
+ ),
+ eol + body,
+ )
+ ]
+
+
+@_formats(
+ "field -> field-location type snake-name abbreviation?"
+ " attribute* doc? Comment? eol field-body?"
+)
+def _unconditional_field(
+ location, type_, name, abbreviation, attributes, doc, comment, eol, body
+):
+ return [
+ _Block(
+ [],
+ _Row(
+ "field",
+ location
+ + [
+ type_,
+ _concatenate_with_spaces(name, abbreviation),
+ attributes,
+ doc,
+ comment,
+ ],
+ ),
+ eol + body,
+ )
+ ]
+
+
+@_formats("field-body -> Indent doc-line* attribute-line* Dedent")
+def _field_body(indent, docs, attributes, dedent):
+ del indent, dedent # Unused
+ return _indent_rows(docs + attributes)
+
+
+@_formats(
+ "anonymous-bits-field-definition ->"
+ ' field-location "bits" ":" Comment? eol anonymous-bits-body'
+)
+def _inline_bits(location, bits, colon, comment, eol, body):
+ # Even though an anonymous bits field technically defines a new, anonymous
+ # type, conceptually it's more like defining a bunch of fields on the
+ # surrounding type, so it is treated as an inline list of blocks, instead of
+ # being separately formatted.
+ header_row = _Row(
+ "field",
+ [location[0], location[1] + " " + bits + colon, "", "", "", "", comment],
+ )
+ return [_Block([], header_row, eol + body.header_lines)] + body.field_blocks
+
+
+@_formats(
+ "inline-enum-field-definition ->"
+ ' field-location "enum" snake-name abbreviation? ":" Comment? eol'
+ " enum-body"
+)
+@_formats(
+ "inline-struct-field-definition ->"
' field-location "struct" snake-name abbreviation? ":" Comment? eol'
- ' struct-body')
-@_formats('inline-bits-field-definition ->'
- ' field-location "bits" snake-name abbreviation? ":" Comment? eol'
- ' bits-body')
-def _inline_type(location, keyword, name, abbreviation, colon, comment, eol,
- body):
- """Formats an inline type in a struct or bits."""
- header_row = _Row(
- 'field', location + [keyword,
- _concatenate_with_spaces(name, abbreviation) + colon,
- '', '', comment])
- return [_Block([], header_row, eol + body)]
+ " struct-body"
+)
+@_formats(
+ "inline-bits-field-definition ->"
+ ' field-location "bits" snake-name abbreviation? ":" Comment? eol'
+ " bits-body"
+)
+def _inline_type(location, keyword, name, abbreviation, colon, comment, eol, body):
+ """Formats an inline type in a struct or bits."""
+ header_row = _Row(
+ "field",
+ location
+ + [
+ keyword,
+ _concatenate_with_spaces(name, abbreviation) + colon,
+ "",
+ "",
+ comment,
+ ],
+ )
+ return [_Block([], header_row, eol + body)]
-@_formats('conditional-struct-field-block -> "if" expression ":" Comment? eol'
- ' Indent unconditional-struct-field+'
- ' Dedent')
-@_formats('conditional-bits-field-block -> "if" expression ":" Comment? eol'
- ' Indent unconditional-bits-field+'
- ' Dedent')
-@_formats('conditional-anonymous-bits-field-block ->'
- ' "if" expression ":" Comment? eol'
- ' Indent unconditional-anonymous-bits-field+ Dedent')
-def _conditional_field(if_, condition, colon, comment, eol, indent, body,
- dedent):
- """Formats an `if` construct."""
- del indent, dedent # Unused
- # The body of an 'if' should be columnized with the surrounding blocks, so
- # much like an inline 'bits', its body is treated as an inline list of blocks.
- header_row = _Row('if',
- ['{} {}{} {}'.format(if_, condition, colon, comment)])
- indented_body = _indent_blocks(body)
- assert indented_body, 'Expected body of if condition.'
- return [_Block([header_row] + eol + indented_body[0].prefix,
- indented_body[0].header,
- indented_body[0].body)] + indented_body[1:]
+@_formats(
+ 'conditional-struct-field-block -> "if" expression ":" Comment? eol'
+ " Indent unconditional-struct-field+"
+ " Dedent"
+)
+@_formats(
+ 'conditional-bits-field-block -> "if" expression ":" Comment? eol'
+ " Indent unconditional-bits-field+"
+ " Dedent"
+)
+@_formats(
+ "conditional-anonymous-bits-field-block ->"
+ ' "if" expression ":" Comment? eol'
+ " Indent unconditional-anonymous-bits-field+ Dedent"
+)
+def _conditional_field(if_, condition, colon, comment, eol, indent, body, dedent):
+ """Formats an `if` construct."""
+ del indent, dedent # Unused
+ # The body of an 'if' should be columnized with the surrounding blocks, so
+ # much like an inline 'bits', its body is treated as an inline list of blocks.
+ header_row = _Row("if", ["{} {}{} {}".format(if_, condition, colon, comment)])
+ indented_body = _indent_blocks(body)
+ assert indented_body, "Expected body of if condition."
+ return [
+ _Block(
+ [header_row] + eol + indented_body[0].prefix,
+ indented_body[0].header,
+ indented_body[0].body,
+ )
+ ] + indented_body[1:]
-_InlineBitsBodyType = collections.namedtuple('InlineBitsBodyType',
- ['header_lines', 'field_blocks'])
+_InlineBitsBodyType = collections.namedtuple(
+ "InlineBitsBodyType", ["header_lines", "field_blocks"]
+)
-@_formats('anonymous-bits-body ->'
- ' Indent attribute-line* anonymous-bits-field-block Dedent')
+@_formats(
+ "anonymous-bits-body ->"
+ " Indent attribute-line* anonymous-bits-field-block Dedent"
+)
def _inline_bits_body(indent, attributes, fields, dedent):
- del indent, dedent # Unused
- return _InlineBitsBodyType(header_lines=_indent_rows(attributes),
- field_blocks=_indent_blocks(fields))
+ del indent, dedent # Unused
+ return _InlineBitsBodyType(
+ header_lines=_indent_rows(attributes), field_blocks=_indent_blocks(fields)
+ )
@_formats_with_config(
- 'enum-body -> Indent doc-line* attribute-line* enum-value+'
- ' Dedent')
+ "enum-body -> Indent doc-line* attribute-line* enum-value+" " Dedent"
+)
def _enum_body(indent, docs, attributes, values, dedent, config):
- del indent, dedent # Unused
- spacing = [_Row('value-separator')] if _should_add_blank_lines(values) else []
- columnized_values = _columnize(values, config.indent_width)
- return _indent_rows(_intersperse(spacing,
- [docs, attributes] + columnized_values))
+ del indent, dedent # Unused
+ spacing = [_Row("value-separator")] if _should_add_blank_lines(values) else []
+ columnized_values = _columnize(values, config.indent_width)
+ return _indent_rows(_intersperse(spacing, [docs, attributes] + columnized_values))
-@_formats('enum-value* -> enum-value enum-value*')
-@_formats('enum-value+ -> enum-value enum-value*')
+@_formats("enum-value* -> enum-value enum-value*")
+@_formats("enum-value+ -> enum-value enum-value*")
def _enum_values(value, block):
- return value + block
+ return value + block
-@_formats('enum-value -> constant-name "=" expression attribute* doc? Comment? eol'
- ' enum-value-body?')
+@_formats(
+ 'enum-value -> constant-name "=" expression attribute* doc? Comment? eol'
+ " enum-value-body?"
+)
def _enum_value(name, equals, value, attributes, docs, comment, eol, body):
- return [_Block([], _Row('enum-value', [name, equals, value, attributes, docs, comment]),
- eol + body)]
+ return [
+ _Block(
+ [],
+ _Row("enum-value", [name, equals, value, attributes, docs, comment]),
+ eol + body,
+ )
+ ]
-@_formats('enum-value-body -> Indent doc-line* attribute-line* Dedent')
+@_formats("enum-value-body -> Indent doc-line* attribute-line* Dedent")
def _enum_value_body(indent, docs, attributes, dedent):
- del indent, dedent # Unused
- return _indent_rows(docs + attributes)
+ del indent, dedent # Unused
+ return _indent_rows(docs + attributes)
-@_formats('external-body -> Indent doc-line* attribute-line* Dedent')
+@_formats("external-body -> Indent doc-line* attribute-line* Dedent")
def _external_body(indent, docs, attributes, dedent):
- del indent, dedent # Unused
- return _indent_rows(_intersperse([_Row('section-break')], [docs, attributes]))
+ del indent, dedent # Unused
+ return _indent_rows(_intersperse([_Row("section-break")], [docs, attributes]))
@_formats('comment-line -> Comment? "\\n"')
def _comment_line(comment, eol):
- del eol # Unused
- if comment:
- return [_Row('comment', [comment])]
- else:
- return [_Row('comment')]
+ del eol # Unused
+ if comment:
+ return [_Row("comment", [comment])]
+ else:
+ return [_Row("comment")]
@_formats('eol -> "\\n" comment-line*')
def _eol(eol, comments):
- del eol # Unused
- return _strip_empty_leading_trailing_comment_lines(comments)
+ del eol # Unused
+ return _strip_empty_leading_trailing_comment_lines(comments)
def _strip_empty_leading_trailing_comment_lines(comments):
- first_non_empty_line = None
- last_non_empty_line = None
- for i in range(len(comments)):
- if comments[i].columns:
- if first_non_empty_line is None:
- first_non_empty_line = i
- last_non_empty_line = i
- if first_non_empty_line is None:
- return []
- else:
- return comments[first_non_empty_line:last_non_empty_line + 1]
+ first_non_empty_line = None
+ last_non_empty_line = None
+ for i in range(len(comments)):
+ if comments[i].columns:
+ if first_non_empty_line is None:
+ first_non_empty_line = i
+ last_non_empty_line = i
+ if first_non_empty_line is None:
+ return []
+ else:
+ return comments[first_non_empty_line : last_non_empty_line + 1]
-@_formats('attribute-line* -> ')
-@_formats('anonymous-bits-field-block -> ')
-@_formats('bits-field-block -> ')
-@_formats('comment-line* -> ')
-@_formats('doc-line* -> ')
-@_formats('enum-value* -> ')
-@_formats('enum-value-body? -> ')
-@_formats('field-body? -> ')
-@_formats('import-line* -> ')
-@_formats('struct-field-block -> ')
-@_formats('type-definition* -> ')
-@_formats('unconditional-anonymous-bits-field* -> ')
-@_formats('unconditional-bits-field* -> ')
-@_formats('unconditional-struct-field* -> ')
+@_formats("attribute-line* -> ")
+@_formats("anonymous-bits-field-block -> ")
+@_formats("bits-field-block -> ")
+@_formats("comment-line* -> ")
+@_formats("doc-line* -> ")
+@_formats("enum-value* -> ")
+@_formats("enum-value-body? -> ")
+@_formats("field-body? -> ")
+@_formats("import-line* -> ")
+@_formats("struct-field-block -> ")
+@_formats("type-definition* -> ")
+@_formats("unconditional-anonymous-bits-field* -> ")
+@_formats("unconditional-bits-field* -> ")
+@_formats("unconditional-struct-field* -> ")
def _empty_list():
- return []
+ return []
-@_formats('abbreviation? -> ')
-@_formats('additive-expression-right* -> ')
-@_formats('and-expression-right* -> ')
-@_formats('argument-list -> ')
-@_formats('array-length-specifier* -> ')
-@_formats('attribute* -> ')
-@_formats('attribute-context? -> ')
-@_formats('comma-then-expression* -> ')
-@_formats('Comment? -> ')
+@_formats("abbreviation? -> ")
+@_formats("additive-expression-right* -> ")
+@_formats("and-expression-right* -> ")
+@_formats("argument-list -> ")
+@_formats("array-length-specifier* -> ")
+@_formats("attribute* -> ")
+@_formats("attribute-context? -> ")
+@_formats("comma-then-expression* -> ")
+@_formats("Comment? -> ")
@_formats('"$default"? -> ')
-@_formats('delimited-argument-list? -> ')
-@_formats('delimited-parameter-definition-list? -> ')
-@_formats('doc? -> ')
-@_formats('equality-expression-right* -> ')
-@_formats('equality-or-greater-expression-right* -> ')
-@_formats('equality-or-less-expression-right* -> ')
-@_formats('field-reference-tail* -> ')
-@_formats('or-expression-right* -> ')
-@_formats('parameter-definition-list -> ')
-@_formats('parameter-definition-list-tail* -> ')
-@_formats('times-expression-right* -> ')
-@_formats('type-size-specifier? -> ')
+@_formats("delimited-argument-list? -> ")
+@_formats("delimited-parameter-definition-list? -> ")
+@_formats("doc? -> ")
+@_formats("equality-expression-right* -> ")
+@_formats("equality-or-greater-expression-right* -> ")
+@_formats("equality-or-less-expression-right* -> ")
+@_formats("field-reference-tail* -> ")
+@_formats("or-expression-right* -> ")
+@_formats("parameter-definition-list -> ")
+@_formats("parameter-definition-list-tail* -> ")
+@_formats("times-expression-right* -> ")
+@_formats("type-size-specifier? -> ")
def _empty_string():
- return ''
+ return ""
-@_formats('abbreviation? -> abbreviation')
+@_formats("abbreviation? -> abbreviation")
@_formats('additive-operator -> "-"')
@_formats('additive-operator -> "+"')
@_formats('and-operator -> "&&"')
-@_formats('attribute-context? -> attribute-context')
-@_formats('attribute-value -> expression')
-@_formats('attribute-value -> string-constant')
-@_formats('boolean-constant -> BooleanConstant')
-@_formats('bottom-expression -> boolean-constant')
-@_formats('bottom-expression -> builtin-reference')
-@_formats('bottom-expression -> constant-reference')
-@_formats('bottom-expression -> field-reference')
-@_formats('bottom-expression -> numeric-constant')
+@_formats("attribute-context? -> attribute-context")
+@_formats("attribute-value -> expression")
+@_formats("attribute-value -> string-constant")
+@_formats("boolean-constant -> BooleanConstant")
+@_formats("bottom-expression -> boolean-constant")
+@_formats("bottom-expression -> builtin-reference")
+@_formats("bottom-expression -> constant-reference")
+@_formats("bottom-expression -> field-reference")
+@_formats("bottom-expression -> numeric-constant")
@_formats('builtin-field-word -> "$max_size_in_bits"')
@_formats('builtin-field-word -> "$max_size_in_bytes"')
@_formats('builtin-field-word -> "$min_size_in_bits"')
@_formats('builtin-field-word -> "$min_size_in_bytes"')
@_formats('builtin-field-word -> "$size_in_bits"')
@_formats('builtin-field-word -> "$size_in_bytes"')
-@_formats('builtin-reference -> builtin-word')
+@_formats("builtin-reference -> builtin-word")
@_formats('builtin-word -> "$is_statically_sized"')
@_formats('builtin-word -> "$next"')
@_formats('builtin-word -> "$static_size_in_bits"')
-@_formats('choice-expression -> logical-expression')
-@_formats('Comment? -> Comment')
-@_formats('comparison-expression -> additive-expression')
-@_formats('constant-name -> constant-word')
-@_formats('constant-reference -> constant-reference-tail')
-@_formats('constant-reference-tail -> constant-word')
-@_formats('constant-word -> ShoutyWord')
+@_formats("choice-expression -> logical-expression")
+@_formats("Comment? -> Comment")
+@_formats("comparison-expression -> additive-expression")
+@_formats("constant-name -> constant-word")
+@_formats("constant-reference -> constant-reference-tail")
+@_formats("constant-reference-tail -> constant-word")
+@_formats("constant-word -> ShoutyWord")
@_formats('"$default"? -> "$default"')
-@_formats('delimited-argument-list? -> delimited-argument-list')
-@_formats('doc? -> doc')
-@_formats('doc -> Documentation')
-@_formats('enum-value-body? -> enum-value-body')
+@_formats("delimited-argument-list? -> delimited-argument-list")
+@_formats("doc? -> doc")
+@_formats("doc -> Documentation")
+@_formats("enum-value-body? -> enum-value-body")
@_formats('equality-operator -> "=="')
-@_formats('equality-or-greater-expression-right -> equality-expression-right')
-@_formats('equality-or-greater-expression-right -> greater-expression-right')
-@_formats('equality-or-less-expression-right -> equality-expression-right')
-@_formats('equality-or-less-expression-right -> less-expression-right')
-@_formats('expression -> choice-expression')
-@_formats('field-body? -> field-body')
+@_formats("equality-or-greater-expression-right -> equality-expression-right")
+@_formats("equality-or-greater-expression-right -> greater-expression-right")
+@_formats("equality-or-less-expression-right -> equality-expression-right")
+@_formats("equality-or-less-expression-right -> less-expression-right")
+@_formats("expression -> choice-expression")
+@_formats("field-body? -> field-body")
@_formats('function-name -> "$lower_bound"')
@_formats('function-name -> "$present"')
@_formats('function-name -> "$max"')
@@ -749,146 +865,186 @@
@_formats('inequality-operator -> "!="')
@_formats('less-operator -> "<="')
@_formats('less-operator -> "<"')
-@_formats('logical-expression -> and-expression')
-@_formats('logical-expression -> comparison-expression')
-@_formats('logical-expression -> or-expression')
+@_formats("logical-expression -> and-expression")
+@_formats("logical-expression -> comparison-expression")
+@_formats("logical-expression -> or-expression")
@_formats('multiplicative-operator -> "*"')
-@_formats('negation-expression -> bottom-expression')
-@_formats('numeric-constant -> Number')
+@_formats("negation-expression -> bottom-expression")
+@_formats("numeric-constant -> Number")
@_formats('or-operator -> "||"')
-@_formats('snake-name -> snake-word')
-@_formats('snake-reference -> builtin-field-word')
-@_formats('snake-reference -> snake-word')
-@_formats('snake-word -> SnakeWord')
-@_formats('string-constant -> String')
-@_formats('type-definition -> bits')
-@_formats('type-definition -> enum')
-@_formats('type-definition -> external')
-@_formats('type-definition -> struct')
-@_formats('type-name -> type-word')
-@_formats('type-reference-tail -> type-word')
-@_formats('type-reference -> type-reference-tail')
-@_formats('type-size-specifier? -> type-size-specifier')
-@_formats('type-word -> CamelWord')
-@_formats('unconditional-anonymous-bits-field -> field')
-@_formats('unconditional-anonymous-bits-field -> inline-bits-field-definition')
-@_formats('unconditional-anonymous-bits-field -> inline-enum-field-definition')
-@_formats('unconditional-bits-field -> unconditional-anonymous-bits-field')
-@_formats('unconditional-bits-field -> virtual-field')
-@_formats('unconditional-struct-field -> anonymous-bits-field-definition')
-@_formats('unconditional-struct-field -> field')
-@_formats('unconditional-struct-field -> inline-bits-field-definition')
-@_formats('unconditional-struct-field -> inline-enum-field-definition')
-@_formats('unconditional-struct-field -> inline-struct-field-definition')
-@_formats('unconditional-struct-field -> virtual-field')
+@_formats("snake-name -> snake-word")
+@_formats("snake-reference -> builtin-field-word")
+@_formats("snake-reference -> snake-word")
+@_formats("snake-word -> SnakeWord")
+@_formats("string-constant -> String")
+@_formats("type-definition -> bits")
+@_formats("type-definition -> enum")
+@_formats("type-definition -> external")
+@_formats("type-definition -> struct")
+@_formats("type-name -> type-word")
+@_formats("type-reference-tail -> type-word")
+@_formats("type-reference -> type-reference-tail")
+@_formats("type-size-specifier? -> type-size-specifier")
+@_formats("type-word -> CamelWord")
+@_formats("unconditional-anonymous-bits-field -> field")
+@_formats("unconditional-anonymous-bits-field -> inline-bits-field-definition")
+@_formats("unconditional-anonymous-bits-field -> inline-enum-field-definition")
+@_formats("unconditional-bits-field -> unconditional-anonymous-bits-field")
+@_formats("unconditional-bits-field -> virtual-field")
+@_formats("unconditional-struct-field -> anonymous-bits-field-definition")
+@_formats("unconditional-struct-field -> field")
+@_formats("unconditional-struct-field -> inline-bits-field-definition")
+@_formats("unconditional-struct-field -> inline-enum-field-definition")
+@_formats("unconditional-struct-field -> inline-struct-field-definition")
+@_formats("unconditional-struct-field -> virtual-field")
def _identity(x):
- return x
+ return x
-@_formats('argument-list -> expression comma-then-expression*')
-@_formats('times-expression -> negation-expression times-expression-right*')
-@_formats('type -> type-reference delimited-argument-list? type-size-specifier?'
- ' array-length-specifier*')
+@_formats("argument-list -> expression comma-then-expression*")
+@_formats("times-expression -> negation-expression times-expression-right*")
+@_formats(
+ "type -> type-reference delimited-argument-list? type-size-specifier?"
+ " array-length-specifier*"
+)
@_formats('array-length-specifier -> "[" expression "]"')
-@_formats('array-length-specifier* -> array-length-specifier'
- ' array-length-specifier*')
+@_formats(
+ "array-length-specifier* -> array-length-specifier"
+ " array-length-specifier*"
+)
@_formats('type-size-specifier -> ":" numeric-constant')
@_formats('attribute-context -> "(" snake-word ")"')
@_formats('constant-reference -> snake-reference "." constant-reference-tail')
@_formats('constant-reference-tail -> type-word "." constant-reference-tail')
@_formats('constant-reference-tail -> type-word "." snake-reference')
@_formats('type-reference-tail -> type-word "." type-reference-tail')
-@_formats('field-reference -> snake-reference field-reference-tail*')
+@_formats("field-reference -> snake-reference field-reference-tail*")
@_formats('abbreviation -> "(" snake-word ")"')
-@_formats('additive-expression-right -> additive-operator times-expression')
-@_formats('additive-expression-right* -> additive-expression-right'
- ' additive-expression-right*')
-@_formats('additive-expression -> times-expression additive-expression-right*')
+@_formats("additive-expression-right -> additive-operator times-expression")
+@_formats(
+ "additive-expression-right* -> additive-expression-right"
+ " additive-expression-right*"
+)
+@_formats("additive-expression -> times-expression additive-expression-right*")
@_formats('array-length-specifier -> "[" "]"')
@_formats('delimited-argument-list -> "(" argument-list ")"')
-@_formats('delimited-parameter-definition-list? ->'
- ' delimited-parameter-definition-list')
-@_formats('delimited-parameter-definition-list ->'
- ' "(" parameter-definition-list ")"')
-@_formats('parameter-definition-list -> parameter-definition'
- ' parameter-definition-list-tail*')
-@_formats('parameter-definition-list-tail* -> parameter-definition-list-tail'
- ' parameter-definition-list-tail*')
-@_formats('times-expression-right -> multiplicative-operator'
- ' negation-expression')
-@_formats('times-expression-right* -> times-expression-right'
- ' times-expression-right*')
+@_formats(
+ "delimited-parameter-definition-list? ->" " delimited-parameter-definition-list"
+)
+@_formats(
+ "delimited-parameter-definition-list ->" ' "(" parameter-definition-list ")"'
+)
+@_formats(
+ "parameter-definition-list -> parameter-definition"
+ " parameter-definition-list-tail*"
+)
+@_formats(
+ "parameter-definition-list-tail* -> parameter-definition-list-tail"
+ " parameter-definition-list-tail*"
+)
+@_formats(
+ "times-expression-right -> multiplicative-operator"
+ " negation-expression"
+)
+@_formats(
+ "times-expression-right* -> times-expression-right"
+ " times-expression-right*"
+)
@_formats('field-reference-tail -> "." snake-reference')
-@_formats('field-reference-tail* -> field-reference-tail field-reference-tail*')
-@_formats('negation-expression -> additive-operator bottom-expression')
+@_formats("field-reference-tail* -> field-reference-tail field-reference-tail*")
+@_formats("negation-expression -> additive-operator bottom-expression")
@_formats('type-reference -> snake-word "." type-reference-tail')
@_formats('bottom-expression -> "(" expression ")"')
@_formats('bottom-expression -> function-name "(" argument-list ")"')
-@_formats('comma-then-expression* -> comma-then-expression'
- ' comma-then-expression*')
-@_formats('or-expression-right* -> or-expression-right or-expression-right*')
-@_formats('less-expression-right-list -> equality-expression-right*'
- ' less-expression-right'
- ' equality-or-less-expression-right*')
-@_formats('or-expression-right+ -> or-expression-right or-expression-right*')
-@_formats('and-expression -> comparison-expression and-expression-right+')
-@_formats('comparison-expression -> additive-expression'
- ' greater-expression-right-list')
-@_formats('comparison-expression -> additive-expression'
- ' equality-expression-right+')
-@_formats('or-expression -> comparison-expression or-expression-right+')
-@_formats('equality-expression-right+ -> equality-expression-right'
- ' equality-expression-right*')
-@_formats('and-expression-right* -> and-expression-right and-expression-right*')
-@_formats('equality-or-greater-expression-right* ->'
- ' equality-or-greater-expression-right'
- ' equality-or-greater-expression-right*')
-@_formats('and-expression-right+ -> and-expression-right and-expression-right*')
-@_formats('equality-or-less-expression-right* ->'
- ' equality-or-less-expression-right'
- ' equality-or-less-expression-right*')
-@_formats('equality-expression-right* -> equality-expression-right'
- ' equality-expression-right*')
-@_formats('greater-expression-right-list ->'
- ' equality-expression-right* greater-expression-right'
- ' equality-or-greater-expression-right*')
-@_formats('comparison-expression -> additive-expression'
- ' less-expression-right-list')
+@_formats(
+ "comma-then-expression* -> comma-then-expression"
+ " comma-then-expression*"
+)
+@_formats("or-expression-right* -> or-expression-right or-expression-right*")
+@_formats(
+ "less-expression-right-list -> equality-expression-right*"
+ " less-expression-right"
+ " equality-or-less-expression-right*"
+)
+@_formats("or-expression-right+ -> or-expression-right or-expression-right*")
+@_formats("and-expression -> comparison-expression and-expression-right+")
+@_formats(
+ "comparison-expression -> additive-expression"
+ " greater-expression-right-list"
+)
+@_formats(
+ "comparison-expression -> additive-expression"
+ " equality-expression-right+"
+)
+@_formats("or-expression -> comparison-expression or-expression-right+")
+@_formats(
+ "equality-expression-right+ -> equality-expression-right"
+ " equality-expression-right*"
+)
+@_formats("and-expression-right* -> and-expression-right and-expression-right*")
+@_formats(
+ "equality-or-greater-expression-right* ->"
+ " equality-or-greater-expression-right"
+ " equality-or-greater-expression-right*"
+)
+@_formats("and-expression-right+ -> and-expression-right and-expression-right*")
+@_formats(
+ "equality-or-less-expression-right* ->"
+ " equality-or-less-expression-right"
+ " equality-or-less-expression-right*"
+)
+@_formats(
+ "equality-expression-right* -> equality-expression-right"
+ " equality-expression-right*"
+)
+@_formats(
+ "greater-expression-right-list ->"
+ " equality-expression-right* greater-expression-right"
+ " equality-or-greater-expression-right*"
+)
+@_formats(
+ "comparison-expression -> additive-expression"
+ " less-expression-right-list"
+)
def _concatenate(*elements):
- """Concatenates all arguments with no delimiters."""
- return ''.join(elements)
+ """Concatenates all arguments with no delimiters."""
+ return "".join(elements)
-@_formats('equality-expression-right -> equality-operator additive-expression')
-@_formats('less-expression-right -> less-operator additive-expression')
-@_formats('greater-expression-right -> greater-operator additive-expression')
-@_formats('or-expression-right -> or-operator comparison-expression')
-@_formats('and-expression-right -> and-operator comparison-expression')
+@_formats("equality-expression-right -> equality-operator additive-expression")
+@_formats("less-expression-right -> less-operator additive-expression")
+@_formats("greater-expression-right -> greater-operator additive-expression")
+@_formats("or-expression-right -> or-operator comparison-expression")
+@_formats("and-expression-right -> and-operator comparison-expression")
def _concatenate_with_prefix_spaces(*elements):
- return ''.join(' ' + element for element in elements if element)
+ return "".join(" " + element for element in elements if element)
-@_formats('attribute* -> attribute attribute*')
+@_formats("attribute* -> attribute attribute*")
@_formats('comma-then-expression -> "," expression')
-@_formats('comparison-expression -> additive-expression inequality-operator'
- ' additive-expression')
-@_formats('choice-expression -> logical-expression "?" logical-expression'
- ' ":" logical-expression')
+@_formats(
+ "comparison-expression -> additive-expression inequality-operator"
+ " additive-expression"
+)
+@_formats(
+ 'choice-expression -> logical-expression "?" logical-expression'
+ ' ":" logical-expression'
+)
@_formats('parameter-definition-list-tail -> "," parameter-definition')
def _concatenate_with_spaces(*elements):
- return _concatenate_with(' ', *elements)
+ return _concatenate_with(" ", *elements)
def _concatenate_with(joiner, *elements):
- return joiner.join(element for element in elements if element)
+ return joiner.join(element for element in elements if element)
-@_formats('attribute-line* -> attribute-line attribute-line*')
-@_formats('comment-line* -> comment-line comment-line*')
-@_formats('doc-line* -> doc-line doc-line*')
-@_formats('import-line* -> import-line import-line*')
+@_formats("attribute-line* -> attribute-line attribute-line*")
+@_formats("comment-line* -> comment-line comment-line*")
+@_formats("doc-line* -> doc-line doc-line*")
+@_formats("import-line* -> import-line import-line*")
def _concatenate_lists(head, tail):
- return head + tail
+ return head + tail
_check_productions()
diff --git a/compiler/front_end/format_emb_test.py b/compiler/front_end/format_emb_test.py
index e633f59..3d5331e 100644
--- a/compiler/front_end/format_emb_test.py
+++ b/compiler/front_end/format_emb_test.py
@@ -31,163 +31,180 @@
class SanityCheckerTest(unittest.TestCase):
- def test_text_does_not_tokenize(self):
- self.assertTrue(format_emb.sanity_check_format_result("-- doc", "~ bad"))
+ def test_text_does_not_tokenize(self):
+ self.assertTrue(format_emb.sanity_check_format_result("-- doc", "~ bad"))
- def test_original_text_does_not_tokenize(self):
- self.assertTrue(format_emb.sanity_check_format_result("~ bad", "-- doc"))
+ def test_original_text_does_not_tokenize(self):
+ self.assertTrue(format_emb.sanity_check_format_result("~ bad", "-- doc"))
- def test_text_matches(self):
- self.assertFalse(format_emb.sanity_check_format_result("-- doc", "-- doc"))
+ def test_text_matches(self):
+ self.assertFalse(format_emb.sanity_check_format_result("-- doc", "-- doc"))
- def test_text_has_extra_eols(self):
- self.assertFalse(
- format_emb.sanity_check_format_result("-- doc\n\n-- doc",
- "-- doc\n\n\n-- doc"))
+ def test_text_has_extra_eols(self):
+ self.assertFalse(
+ format_emb.sanity_check_format_result(
+ "-- doc\n\n-- doc", "-- doc\n\n\n-- doc"
+ )
+ )
- def test_text_has_fewer_eols(self):
- self.assertFalse(format_emb.sanity_check_format_result("-- doc\n\n-- doc",
- "-- doc\n-- doc"))
+ def test_text_has_fewer_eols(self):
+ self.assertFalse(
+ format_emb.sanity_check_format_result("-- doc\n\n-- doc", "-- doc\n-- doc")
+ )
- def test_original_text_has_leading_eols(self):
- self.assertFalse(format_emb.sanity_check_format_result("\n\n-- doc\n",
- "-- doc\n"))
+ def test_original_text_has_leading_eols(self):
+ self.assertFalse(
+ format_emb.sanity_check_format_result("\n\n-- doc\n", "-- doc\n")
+ )
- def test_original_text_has_extra_doc_whitespace(self):
- self.assertFalse(format_emb.sanity_check_format_result("-- doc \n",
- "-- doc\n"))
+ def test_original_text_has_extra_doc_whitespace(self):
+ self.assertFalse(
+ format_emb.sanity_check_format_result("-- doc \n", "-- doc\n")
+ )
- def test_comments_differ(self):
- self.assertTrue(format_emb.sanity_check_format_result("#c\n-- doc\n",
- "#d\n-- doc\n"))
+ def test_comments_differ(self):
+ self.assertTrue(
+ format_emb.sanity_check_format_result("#c\n-- doc\n", "#d\n-- doc\n")
+ )
- def test_comment_missing(self):
- self.assertTrue(format_emb.sanity_check_format_result("#c\n-- doc\n",
- "\n-- doc\n"))
+ def test_comment_missing(self):
+ self.assertTrue(
+ format_emb.sanity_check_format_result("#c\n-- doc\n", "\n-- doc\n")
+ )
- def test_comment_added(self):
- self.assertTrue(format_emb.sanity_check_format_result("\n-- doc\n",
- "#d\n-- doc\n"))
+ def test_comment_added(self):
+ self.assertTrue(
+ format_emb.sanity_check_format_result("\n-- doc\n", "#d\n-- doc\n")
+ )
- def test_token_text_differs(self):
- self.assertTrue(format_emb.sanity_check_format_result("-- doc\n",
- "-- bad doc\n"))
+ def test_token_text_differs(self):
+ self.assertTrue(
+ format_emb.sanity_check_format_result("-- doc\n", "-- bad doc\n")
+ )
- def test_token_type_differs(self):
- self.assertTrue(format_emb.sanity_check_format_result("-- doc\n",
- "abc\n"))
+ def test_token_type_differs(self):
+ self.assertTrue(format_emb.sanity_check_format_result("-- doc\n", "abc\n"))
- def test_eol_missing(self):
- self.assertTrue(format_emb.sanity_check_format_result("abc\n-- doc\n",
- "abc -- doc\n"))
+ def test_eol_missing(self):
+ self.assertTrue(
+ format_emb.sanity_check_format_result("abc\n-- doc\n", "abc -- doc\n")
+ )
class FormatEmbTest(unittest.TestCase):
- pass
+ pass
def _make_golden_file_tests():
- """Generates test cases from the golden files in the resource bundle."""
+ """Generates test cases from the golden files in the resource bundle."""
- package = "testdata.format"
- path_prefix = ""
+ package = "testdata.format"
+ path_prefix = ""
- def make_test_case(name, unformatted_text, expected_text, indent_width):
+ def make_test_case(name, unformatted_text, expected_text, indent_width):
- def test_case(self):
- self.maxDiff = 100000
- unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, name)
- self.assertFalse(errors)
- parsed_unformatted = parser.parse_module(unformatted_tokens)
- self.assertFalse(parsed_unformatted.error)
- formatted_text = format_emb.format_emboss_parse_tree(
- parsed_unformatted.parse_tree,
- format_emb.Config(indent_width=indent_width))
- self.assertEqual(expected_text, formatted_text)
- annotated_text = format_emb.format_emboss_parse_tree(
- parsed_unformatted.parse_tree,
- format_emb.Config(indent_width=indent_width, show_line_types=True))
- self.assertEqual(expected_text, re.sub(r"^.*?\|", "", annotated_text,
- flags=re.MULTILINE))
- self.assertFalse(re.search("^[^|]+$", annotated_text, flags=re.MULTILINE))
+ def test_case(self):
+ self.maxDiff = 100000
+ unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, name)
+ self.assertFalse(errors)
+ parsed_unformatted = parser.parse_module(unformatted_tokens)
+ self.assertFalse(parsed_unformatted.error)
+ formatted_text = format_emb.format_emboss_parse_tree(
+ parsed_unformatted.parse_tree,
+ format_emb.Config(indent_width=indent_width),
+ )
+ self.assertEqual(expected_text, formatted_text)
+ annotated_text = format_emb.format_emboss_parse_tree(
+ parsed_unformatted.parse_tree,
+ format_emb.Config(indent_width=indent_width, show_line_types=True),
+ )
+ self.assertEqual(
+ expected_text, re.sub(r"^.*?\|", "", annotated_text, flags=re.MULTILINE)
+ )
+ self.assertFalse(re.search("^[^|]+$", annotated_text, flags=re.MULTILINE))
- return test_case
+ return test_case
- all_unformatted_texts = []
+ all_unformatted_texts = []
- for filename in (
- "abbreviations",
- "anonymous_bits_formatting",
- "arithmetic_expressions",
- "array_length",
- "attributes",
- "choice_expression",
- "comparison_expressions",
- "conditional_field_formatting",
- "conditional_inline_bits_formatting",
- "dotted_names",
- "empty",
- "enum_value_attributes",
- "enum_value_bodies",
- "enum_values_aligned",
- "equality_expressions",
- "external",
- "extra_newlines",
- "fields_aligned",
- "functions",
- "header_and_type",
- "indent",
- "inline_attributes_get_a_column",
- "inline_bits",
- "inline_documentation_gets_a_column",
- "inline_enum",
- "inline_struct",
- "lines_not_spaced_out_with_excess_trailing_noise_lines",
- "lines_not_spaced_out_with_not_enough_noise_lines",
- "lines_spaced_out_with_noise_lines",
- "logical_expressions",
- "multiline_ifs",
- "multiple_header_sections",
- "nested_types_are_columnized_independently",
- "one_type",
- "parameterized_struct",
- "sanity_check",
- "spacing_between_types",
- "trailing_spaces",
- "virtual_fields"):
- for suffix, width in ((".emb.formatted", 2),
- (".emb.formatted_indent_4", 4)):
- unformatted_name = path_prefix + filename + ".emb"
- expected_name = path_prefix + filename + suffix
- unformatted_text = pkgutil.get_data(package,
- unformatted_name).decode("utf-8")
- expected_text = pkgutil.get_data(package, expected_name).decode("utf-8")
- setattr(FormatEmbTest, "test {} indent {}".format(filename, width),
- make_test_case(filename, unformatted_text, expected_text, width))
+ for filename in (
+ "abbreviations",
+ "anonymous_bits_formatting",
+ "arithmetic_expressions",
+ "array_length",
+ "attributes",
+ "choice_expression",
+ "comparison_expressions",
+ "conditional_field_formatting",
+ "conditional_inline_bits_formatting",
+ "dotted_names",
+ "empty",
+ "enum_value_attributes",
+ "enum_value_bodies",
+ "enum_values_aligned",
+ "equality_expressions",
+ "external",
+ "extra_newlines",
+ "fields_aligned",
+ "functions",
+ "header_and_type",
+ "indent",
+ "inline_attributes_get_a_column",
+ "inline_bits",
+ "inline_documentation_gets_a_column",
+ "inline_enum",
+ "inline_struct",
+ "lines_not_spaced_out_with_excess_trailing_noise_lines",
+ "lines_not_spaced_out_with_not_enough_noise_lines",
+ "lines_spaced_out_with_noise_lines",
+ "logical_expressions",
+ "multiline_ifs",
+ "multiple_header_sections",
+ "nested_types_are_columnized_independently",
+ "one_type",
+ "parameterized_struct",
+ "sanity_check",
+ "spacing_between_types",
+ "trailing_spaces",
+ "virtual_fields",
+ ):
+ for suffix, width in ((".emb.formatted", 2), (".emb.formatted_indent_4", 4)):
+ unformatted_name = path_prefix + filename + ".emb"
+ expected_name = path_prefix + filename + suffix
+ unformatted_text = pkgutil.get_data(package, unformatted_name).decode(
+ "utf-8"
+ )
+ expected_text = pkgutil.get_data(package, expected_name).decode("utf-8")
+ setattr(
+ FormatEmbTest,
+ "test {} indent {}".format(filename, width),
+ make_test_case(filename, unformatted_text, expected_text, width),
+ )
- all_unformatted_texts.append(unformatted_text)
+ all_unformatted_texts.append(unformatted_text)
- def test_all_productions_used(self):
- used_productions = set()
- for unformatted_text in all_unformatted_texts:
- unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, "")
- self.assertFalse(errors)
- parsed_unformatted = parser.parse_module(unformatted_tokens)
- self.assertFalse(parsed_unformatted.error)
- format_emb.format_emboss_parse_tree(parsed_unformatted.parse_tree,
- format_emb.Config(), used_productions)
- unused_productions = set(module_ir.PRODUCTIONS) - used_productions
- if unused_productions:
- print("Used production total:", len(used_productions), file=sys.stderr)
- for production in unused_productions:
- print("Unused production:", str(production), file=sys.stderr)
- print("Total:", len(unused_productions), file=sys.stderr)
- self.assertEqual(set(module_ir.PRODUCTIONS), used_productions)
+ def test_all_productions_used(self):
+ used_productions = set()
+ for unformatted_text in all_unformatted_texts:
+ unformatted_tokens, errors = tokenizer.tokenize(unformatted_text, "")
+ self.assertFalse(errors)
+ parsed_unformatted = parser.parse_module(unformatted_tokens)
+ self.assertFalse(parsed_unformatted.error)
+ format_emb.format_emboss_parse_tree(
+ parsed_unformatted.parse_tree, format_emb.Config(), used_productions
+ )
+ unused_productions = set(module_ir.PRODUCTIONS) - used_productions
+ if unused_productions:
+ print("Used production total:", len(used_productions), file=sys.stderr)
+ for production in unused_productions:
+ print("Unused production:", str(production), file=sys.stderr)
+ print("Total:", len(unused_productions), file=sys.stderr)
+ self.assertEqual(set(module_ir.PRODUCTIONS), used_productions)
- FormatEmbTest.testAllProductionsUsed = test_all_productions_used
+ FormatEmbTest.testAllProductionsUsed = test_all_productions_used
_make_golden_file_tests()
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/generate_grammar_md.py b/compiler/front_end/generate_grammar_md.py
index 6cb2c58..53c9e50 100644
--- a/compiler/front_end/generate_grammar_md.py
+++ b/compiler/front_end/generate_grammar_md.py
@@ -65,172 +65,179 @@
def _sort_productions(productions, start_symbol):
- """Sorts the given productions in a human-friendly order."""
- productions_by_lhs = {}
- for p in productions:
- if p.lhs not in productions_by_lhs:
- productions_by_lhs[p.lhs] = set()
- productions_by_lhs[p.lhs].add(p)
+ """Sorts the given productions in a human-friendly order."""
+ productions_by_lhs = {}
+ for p in productions:
+ if p.lhs not in productions_by_lhs:
+ productions_by_lhs[p.lhs] = set()
+ productions_by_lhs[p.lhs].add(p)
- queue = [start_symbol]
- previously_queued_symbols = set(queue)
- main_production_list = []
- # This sorts productions depth-first. I'm not sure if it is better to sort
- # them breadth-first or depth-first, or with some hybrid.
- while queue:
- symbol = queue.pop(-1)
- if symbol not in productions_by_lhs:
- continue
- for production in sorted(productions_by_lhs[symbol]):
- main_production_list.append(production)
- for symbol in production.rhs:
- # Skip boilerplate productions for now, but include their base
- # production.
- if symbol and symbol[-1] in "*+?":
- symbol = symbol[0:-1]
- if symbol not in previously_queued_symbols:
- queue.append(symbol)
- previously_queued_symbols.add(symbol)
+ queue = [start_symbol]
+ previously_queued_symbols = set(queue)
+ main_production_list = []
+ # This sorts productions depth-first. I'm not sure if it is better to sort
+ # them breadth-first or depth-first, or with some hybrid.
+ while queue:
+ symbol = queue.pop(-1)
+ if symbol not in productions_by_lhs:
+ continue
+ for production in sorted(productions_by_lhs[symbol]):
+ main_production_list.append(production)
+ for symbol in production.rhs:
+ # Skip boilerplate productions for now, but include their base
+ # production.
+ if symbol and symbol[-1] in "*+?":
+ symbol = symbol[0:-1]
+ if symbol not in previously_queued_symbols:
+ queue.append(symbol)
+ previously_queued_symbols.add(symbol)
- # It's not particularly important to put boilerplate productions in any
- # particular order.
- boilerplate_production_list = sorted(
- set(productions) - set(main_production_list))
- for production in boilerplate_production_list:
- assert production.lhs[-1] in "*+?", "Found orphaned production {}".format(
- production.lhs)
- assert set(productions) == set(
- main_production_list + boilerplate_production_list)
- assert len(productions) == len(main_production_list) + len(
- boilerplate_production_list)
- return main_production_list, boilerplate_production_list
+ # It's not particularly important to put boilerplate productions in any
+ # particular order.
+ boilerplate_production_list = sorted(set(productions) - set(main_production_list))
+ for production in boilerplate_production_list:
+ assert production.lhs[-1] in "*+?", "Found orphaned production {}".format(
+ production.lhs
+ )
+ assert set(productions) == set(main_production_list + boilerplate_production_list)
+ assert len(productions) == len(main_production_list) + len(
+ boilerplate_production_list
+ )
+ return main_production_list, boilerplate_production_list
def _word_wrap_at_column(words, width):
- """Wraps words to the specified width, and returns a list of wrapped lines."""
- result = []
- in_progress = []
- for word in words:
- if len(" ".join(in_progress + [word])) > width:
- result.append(" ".join(in_progress))
- assert len(result[-1]) <= width
- in_progress = []
- in_progress.append(word)
- result.append(" ".join(in_progress))
- assert len(result[-1]) <= width
- return result
+ """Wraps words to the specified width, and returns a list of wrapped lines."""
+ result = []
+ in_progress = []
+ for word in words:
+ if len(" ".join(in_progress + [word])) > width:
+ result.append(" ".join(in_progress))
+ assert len(result[-1]) <= width
+ in_progress = []
+ in_progress.append(word)
+ result.append(" ".join(in_progress))
+ assert len(result[-1]) <= width
+ return result
def _format_productions(productions):
- """Formats a list of productions for inclusion in a Markdown document."""
- max_lhs_len = max([len(production.lhs) for production in productions])
+ """Formats a list of productions for inclusion in a Markdown document."""
+ max_lhs_len = max([len(production.lhs) for production in productions])
- # TODO(bolms): This highlighting is close for now, but not actually right.
- result = ["```shell\n"]
- last_lhs = None
- for production in productions:
- if last_lhs == production.lhs:
- lhs = ""
- delimiter = " |"
- else:
- lhs = production.lhs
- delimiter = "->"
- leader = "{lhs:{width}} {delimiter}".format(
- lhs=lhs,
- width=max_lhs_len,
- delimiter=delimiter)
- for rhs_block in _word_wrap_at_column(
- production.rhs or ["<empty>"], _MAX_OUTPUT_WIDTH - len(leader)):
- result.append("{leader} {rhs}\n".format(leader=leader, rhs=rhs_block))
- leader = " " * len(leader)
- last_lhs = production.lhs
- result.append("```\n")
- return "".join(result)
+ # TODO(bolms): This highlighting is close for now, but not actually right.
+ result = ["```shell\n"]
+ last_lhs = None
+ for production in productions:
+ if last_lhs == production.lhs:
+ lhs = ""
+ delimiter = " |"
+ else:
+ lhs = production.lhs
+ delimiter = "->"
+ leader = "{lhs:{width}} {delimiter}".format(
+ lhs=lhs, width=max_lhs_len, delimiter=delimiter
+ )
+ for rhs_block in _word_wrap_at_column(
+ production.rhs or ["<empty>"], _MAX_OUTPUT_WIDTH - len(leader)
+ ):
+ result.append("{leader} {rhs}\n".format(leader=leader, rhs=rhs_block))
+ leader = " " * len(leader)
+ last_lhs = production.lhs
+ result.append("```\n")
+ return "".join(result)
def _normalize_literal_patterns(literals):
- """Normalizes a list of strings to a list of (regex, symbol) pairs."""
- return [(re.sub(r"(\W)", r"\\\1", literal), '"' + literal + '"')
- for literal in literals]
+ """Normalizes a list of strings to a list of (regex, symbol) pairs."""
+ return [
+ (re.sub(r"(\W)", r"\\\1", literal), '"' + literal + '"') for literal in literals
+ ]
def _normalize_regex_patterns(regexes):
- """Normalizes a list of tokenizer regexes to a list of (regex, symbol)."""
- # g3doc breaks up patterns containing '|' when they are inserted into a table,
- # unless they're preceded by '\'. Note that other special characters,
- # including '\', should *not* be escaped with '\'.
- return [(re.sub(r"\|", r"\\|", r.regex.pattern), r.symbol) for r in regexes]
+ """Normalizes a list of tokenizer regexes to a list of (regex, symbol)."""
+ # g3doc breaks up patterns containing '|' when they are inserted into a table,
+ # unless they're preceded by '\'. Note that other special characters,
+ # including '\', should *not* be escaped with '\'.
+ return [(re.sub(r"\|", r"\\|", r.regex.pattern), r.symbol) for r in regexes]
def _normalize_reserved_word_list(reserved_words):
- """Returns words that would be allowed as names if they were not reserved."""
- interesting_reserved_words = []
- for word in reserved_words:
- tokens, errors = tokenizer.tokenize(word, "")
- assert tokens and not errors, "Failed to tokenize " + word
- if tokens[0].symbol in ["SnakeWord", "CamelWord", "ShoutyWord"]:
- interesting_reserved_words.append(word)
- return sorted(interesting_reserved_words)
+ """Returns words that would be allowed as names if they were not reserved."""
+ interesting_reserved_words = []
+ for word in reserved_words:
+ tokens, errors = tokenizer.tokenize(word, "")
+ assert tokens and not errors, "Failed to tokenize " + word
+ if tokens[0].symbol in ["SnakeWord", "CamelWord", "ShoutyWord"]:
+ interesting_reserved_words.append(word)
+ return sorted(interesting_reserved_words)
def _format_token_rules(token_rules):
- """Formats a list of (pattern, symbol) pairs as a table."""
- pattern_width = max([len(rule[0]) for rule in token_rules])
- pattern_width += 2 # For the `` characters.
- result = ["{pat_header:{width}} | Symbol\n"
- "{empty:-<{width}} | {empty:-<30}\n".format(pat_header="Pattern",
- width=pattern_width,
- empty="")]
- for rule in token_rules:
- if rule[1]:
- symbol_name = "`" + rule[1] + "`"
- else:
- symbol_name = "*no symbol emitted*"
- result.append(
- "{pattern:{width}} | {symbol}\n".format(pattern="`" + rule[0] + "`",
- width=pattern_width,
- symbol=symbol_name))
- return "".join(result)
+ """Formats a list of (pattern, symbol) pairs as a table."""
+ pattern_width = max([len(rule[0]) for rule in token_rules])
+ pattern_width += 2 # For the `` characters.
+ result = [
+ "{pat_header:{width}} | Symbol\n"
+ "{empty:-<{width}} | {empty:-<30}\n".format(
+ pat_header="Pattern", width=pattern_width, empty=""
+ )
+ ]
+ for rule in token_rules:
+ if rule[1]:
+ symbol_name = "`" + rule[1] + "`"
+ else:
+ symbol_name = "*no symbol emitted*"
+ result.append(
+ "{pattern:{width}} | {symbol}\n".format(
+ pattern="`" + rule[0] + "`", width=pattern_width, symbol=symbol_name
+ )
+ )
+ return "".join(result)
def _format_keyword_list(reserved_words):
- """formats a list of reserved words."""
- lines = []
- current_line = ""
- for word in reserved_words:
- if len(current_line) + len(word) + 2 > 80:
- lines.append(current_line)
- current_line = ""
- current_line += "`{}` ".format(word)
- return "".join([line[:-1] + "\n" for line in lines])
+ """formats a list of reserved words."""
+ lines = []
+ current_line = ""
+ for word in reserved_words:
+ if len(current_line) + len(word) + 2 > 80:
+ lines.append(current_line)
+ current_line = ""
+ current_line += "`{}` ".format(word)
+ return "".join([line[:-1] + "\n" for line in lines])
def generate_grammar_md():
- """Generates up-to-date text for grammar.md."""
- main_productions, boilerplate_productions = _sort_productions(
- module_ir.PRODUCTIONS, module_ir.START_SYMBOL)
- result = [_HEADER, _format_productions(main_productions),
- _BOILERPLATE_PRODUCTION_HEADER,
- _format_productions(boilerplate_productions)]
+ """Generates up-to-date text for grammar.md."""
+ main_productions, boilerplate_productions = _sort_productions(
+ module_ir.PRODUCTIONS, module_ir.START_SYMBOL
+ )
+ result = [
+ _HEADER,
+ _format_productions(main_productions),
+ _BOILERPLATE_PRODUCTION_HEADER,
+ _format_productions(boilerplate_productions),
+ ]
- main_tokens = _normalize_literal_patterns(tokenizer.LITERAL_TOKEN_PATTERNS)
- main_tokens += _normalize_regex_patterns(tokenizer.REGEX_TOKEN_PATTERNS)
- result.append(_TOKENIZER_RULE_HEADER)
- result.append(_format_token_rules(main_tokens))
+ main_tokens = _normalize_literal_patterns(tokenizer.LITERAL_TOKEN_PATTERNS)
+ main_tokens += _normalize_regex_patterns(tokenizer.REGEX_TOKEN_PATTERNS)
+ result.append(_TOKENIZER_RULE_HEADER)
+ result.append(_format_token_rules(main_tokens))
- reserved_words = _normalize_reserved_word_list(
- constraints.get_reserved_word_list())
- result.append(_KEYWORDS_HEADER.format(len(reserved_words)))
- result.append(_format_keyword_list(reserved_words))
+ reserved_words = _normalize_reserved_word_list(constraints.get_reserved_word_list())
+ result.append(_KEYWORDS_HEADER.format(len(reserved_words)))
+ result.append(_format_keyword_list(reserved_words))
- return "".join(result)
+ return "".join(result)
def main(argv):
- del argv # Unused.
- print(generate_grammar_md(), end="")
- return 0
+ del argv # Unused.
+ print(generate_grammar_md(), end="")
+ return 0
if __name__ == "__main__":
- sys.exit(main(sys.argv))
+ sys.exit(main(sys.argv))
diff --git a/compiler/front_end/glue.py b/compiler/front_end/glue.py
index a1e1a5b..2744086 100644
--- a/compiler/front_end/glue.py
+++ b/compiler/front_end/glue.py
@@ -38,330 +38,347 @@
from compiler.util import parser_types
from compiler.util import resources
-_IrDebugInfo = collections.namedtuple("IrDebugInfo", ["ir", "debug_info",
- "errors"])
+_IrDebugInfo = collections.namedtuple("IrDebugInfo", ["ir", "debug_info", "errors"])
class DebugInfo(object):
- """Debug information about Emboss parsing."""
- __slots__ = ("modules")
+ """Debug information about Emboss parsing."""
- def __init__(self):
- self.modules = {}
+ __slots__ = "modules"
- def __eq__(self, other):
- return self.modules == other.modules
+ def __init__(self):
+ self.modules = {}
- def __ne__(self, other):
- return not self == other
+ def __eq__(self, other):
+ return self.modules == other.modules
+
+ def __ne__(self, other):
+ return not self == other
class ModuleDebugInfo(object):
- """Debug information about the parse of a single file.
+ """Debug information about the parse of a single file.
- Attributes:
- file_name: The name of the file from which this module came.
- tokens: The tokenization of this module's source text.
- parse_tree: The raw parse tree for this module.
- ir: The intermediate representation of this module, before additional
- processing such as symbol resolution.
- used_productions: The set of grammar productions used when parsing this
- module.
- source_code: The source text of the module.
- """
- __slots__ = ("file_name", "tokens", "parse_tree", "ir", "used_productions",
- "source_code")
+ Attributes:
+ file_name: The name of the file from which this module came.
+ tokens: The tokenization of this module's source text.
+ parse_tree: The raw parse tree for this module.
+ ir: The intermediate representation of this module, before additional
+ processing such as symbol resolution.
+ used_productions: The set of grammar productions used when parsing this
+ module.
+ source_code: The source text of the module.
+ """
- def __init__(self, file_name):
- self.file_name = file_name
- self.tokens = None
- self.parse_tree = None
- self.ir = None
- self.used_productions = None
- self.source_code = None
+ __slots__ = (
+ "file_name",
+ "tokens",
+ "parse_tree",
+ "ir",
+ "used_productions",
+ "source_code",
+ )
- def __eq__(self, other):
- return (self.file_name == other.file_name and self.tokens == other.tokens
- and self.parse_tree == other.parse_tree and self.ir == other.ir and
- self.used_productions == other.used_productions and
- self.source_code == other.source_code)
+ def __init__(self, file_name):
+ self.file_name = file_name
+ self.tokens = None
+ self.parse_tree = None
+ self.ir = None
+ self.used_productions = None
+ self.source_code = None
- def __ne__(self, other):
- return not self == other
+ def __eq__(self, other):
+ return (
+ self.file_name == other.file_name
+ and self.tokens == other.tokens
+ and self.parse_tree == other.parse_tree
+ and self.ir == other.ir
+ and self.used_productions == other.used_productions
+ and self.source_code == other.source_code
+ )
- def format_tokenization(self):
- """Renders self.tokens in a human-readable format."""
- return "\n".join([str(token) for token in self.tokens])
+ def __ne__(self, other):
+ return not self == other
- def format_parse_tree(self, parse_tree=None, indent=""):
- """Renders self.parse_tree in a human-readable format."""
- if parse_tree is None:
- parse_tree = self.parse_tree
- result = []
- if isinstance(parse_tree, lr1.Reduction):
- result.append(indent + parse_tree.symbol)
- if parse_tree.children:
- result.append(":\n")
- for child in parse_tree.children:
- result.append(self.format_parse_tree(child, indent + " "))
- else:
- result.append("\n")
- else:
- result.append("{}{}\n".format(indent, parse_tree))
- return "".join(result)
+ def format_tokenization(self):
+ """Renders self.tokens in a human-readable format."""
+ return "\n".join([str(token) for token in self.tokens])
- def format_module_ir(self):
- """Renders self.ir in a human-readable format."""
- return ir_data_utils.IrDataSerializer(self.ir).to_json(indent=2)
+ def format_parse_tree(self, parse_tree=None, indent=""):
+ """Renders self.parse_tree in a human-readable format."""
+ if parse_tree is None:
+ parse_tree = self.parse_tree
+ result = []
+ if isinstance(parse_tree, lr1.Reduction):
+ result.append(indent + parse_tree.symbol)
+ if parse_tree.children:
+ result.append(":\n")
+ for child in parse_tree.children:
+ result.append(self.format_parse_tree(child, indent + " "))
+ else:
+ result.append("\n")
+ else:
+ result.append("{}{}\n".format(indent, parse_tree))
+ return "".join(result)
+
+ def format_module_ir(self):
+ """Renders self.ir in a human-readable format."""
+ return ir_data_utils.IrDataSerializer(self.ir).to_json(indent=2)
def format_production_set(productions):
- """Renders a set of productions in a human-readable format."""
- return "\n".join([str(production) for production in sorted(productions)])
+ """Renders a set of productions in a human-readable format."""
+ return "\n".join([str(production) for production in sorted(productions)])
_cached_modules = {}
def parse_module_text(source_code, file_name):
- """Parses the text of a module, returning a module-level IR.
+ """Parses the text of a module, returning a module-level IR.
- Arguments:
- source_code: The text of the module to parse.
- file_name: The name of the module's source file (will be included in the
- resulting IR).
+ Arguments:
+ source_code: The text of the module to parse.
+ file_name: The name of the module's source file (will be included in the
+ resulting IR).
- Returns:
- A module-level intermediate representation (IR), prior to import and symbol
- resolution, and a corresponding ModuleDebugInfo, for debugging the parser.
+ Returns:
+ A module-level intermediate representation (IR), prior to import and symbol
+ resolution, and a corresponding ModuleDebugInfo, for debugging the parser.
- Raises:
- FrontEndFailure: An error occurred while parsing the module. str(error)
- will give a human-readable error message.
- """
- # This is strictly an optimization to speed up tests, mostly by avoiding the
- # need to re-parse the prelude for every test .emb.
- if (source_code, file_name) in _cached_modules:
- debug_info = _cached_modules[source_code, file_name]
- ir = ir_data_utils.copy(debug_info.ir)
- else:
- debug_info = ModuleDebugInfo(file_name)
- debug_info.source_code = source_code
- tokens, errors = tokenizer.tokenize(source_code, file_name)
- if errors:
- return _IrDebugInfo(None, debug_info, errors)
- debug_info.tokens = tokens
- parse_result = parser.parse_module(tokens)
- if parse_result.error:
- return _IrDebugInfo(
- None,
- debug_info,
- [error.make_error_from_parse_error(file_name, parse_result.error)])
- debug_info.parse_tree = parse_result.parse_tree
- used_productions = set()
- ir = module_ir.build_ir(parse_result.parse_tree, used_productions)
- ir.source_text = source_code
- debug_info.used_productions = used_productions
- debug_info.ir = ir_data_utils.copy(ir)
- _cached_modules[source_code, file_name] = debug_info
- ir.source_file_name = file_name
- return _IrDebugInfo(ir, debug_info, [])
+ Raises:
+ FrontEndFailure: An error occurred while parsing the module. str(error)
+ will give a human-readable error message.
+ """
+ # This is strictly an optimization to speed up tests, mostly by avoiding the
+ # need to re-parse the prelude for every test .emb.
+ if (source_code, file_name) in _cached_modules:
+ debug_info = _cached_modules[source_code, file_name]
+ ir = ir_data_utils.copy(debug_info.ir)
+ else:
+ debug_info = ModuleDebugInfo(file_name)
+ debug_info.source_code = source_code
+ tokens, errors = tokenizer.tokenize(source_code, file_name)
+ if errors:
+ return _IrDebugInfo(None, debug_info, errors)
+ debug_info.tokens = tokens
+ parse_result = parser.parse_module(tokens)
+ if parse_result.error:
+ return _IrDebugInfo(
+ None,
+ debug_info,
+ [error.make_error_from_parse_error(file_name, parse_result.error)],
+ )
+ debug_info.parse_tree = parse_result.parse_tree
+ used_productions = set()
+ ir = module_ir.build_ir(parse_result.parse_tree, used_productions)
+ ir.source_text = source_code
+ debug_info.used_productions = used_productions
+ debug_info.ir = ir_data_utils.copy(ir)
+ _cached_modules[source_code, file_name] = debug_info
+ ir.source_file_name = file_name
+ return _IrDebugInfo(ir, debug_info, [])
def parse_module(file_name, file_reader):
- """Parses a module, returning a module-level IR.
+ """Parses a module, returning a module-level IR.
- Arguments:
- file_name: The name of the module's source file.
- file_reader: A callable that returns either:
- (file_contents, None) or
- (None, list_of_error_detail_strings)
+ Arguments:
+ file_name: The name of the module's source file.
+ file_reader: A callable that returns either:
+ (file_contents, None) or
+ (None, list_of_error_detail_strings)
- Returns:
- (ir, debug_info, errors), where ir is a module-level intermediate
- representation (IR), debug_info is a ModuleDebugInfo containing the
- tokenization, parse tree, and original source text of all modules, and
- errors is a list of tokenization or parse errors. If errors is not an empty
- list, ir will be None.
+ Returns:
+ (ir, debug_info, errors), where ir is a module-level intermediate
+ representation (IR), debug_info is a ModuleDebugInfo containing the
+ tokenization, parse tree, and original source text of all modules, and
+ errors is a list of tokenization or parse errors. If errors is not an empty
+ list, ir will be None.
- Raises:
- FrontEndFailure: An error occurred while reading or parsing the module.
- str(error) will give a human-readable error message.
- """
- source_code, errors = file_reader(file_name)
- if errors:
- location = parser_types.make_location((1, 1), (1, 1))
- return None, None, [
- [error.error(file_name, location, "Unable to read file.")] +
- [error.note(file_name, location, e) for e in errors]
- ]
- return parse_module_text(source_code, file_name)
+ Raises:
+ FrontEndFailure: An error occurred while reading or parsing the module.
+ str(error) will give a human-readable error message.
+ """
+ source_code, errors = file_reader(file_name)
+ if errors:
+ location = parser_types.make_location((1, 1), (1, 1))
+ return (
+ None,
+ None,
+ [
+ [error.error(file_name, location, "Unable to read file.")]
+ + [error.note(file_name, location, e) for e in errors]
+ ],
+ )
+ return parse_module_text(source_code, file_name)
def get_prelude():
- """Returns the module IR and debug info of the Emboss Prelude."""
- return parse_module_text(
- resources.load("compiler.front_end", "prelude.emb"), "")
+ """Returns the module IR and debug info of the Emboss Prelude."""
+ return parse_module_text(resources.load("compiler.front_end", "prelude.emb"), "")
def parse_emboss_file(file_name, file_reader, stop_before_step=None):
- """Fully parses an .emb, and returns an IR suitable for passing to a back end.
+ """Fully parses an .emb, and returns an IR suitable for passing to a back end.
- parse_emboss_file is a convenience function which calls only_parse_emboss_file
- and process_ir.
+ parse_emboss_file is a convenience function which calls only_parse_emboss_file
+ and process_ir.
- Arguments:
- file_name: The name of the module's source file.
- file_reader: A callable that returns the contents of files, or raises
- IOError.
- stop_before_step: If set, parse_emboss_file will stop normalizing the IR
- just before the specified step. This parameter should be None for
- non-test code.
+ Arguments:
+ file_name: The name of the module's source file.
+ file_reader: A callable that returns the contents of files, or raises
+ IOError.
+ stop_before_step: If set, parse_emboss_file will stop normalizing the IR
+ just before the specified step. This parameter should be None for
+ non-test code.
- Returns:
- (ir, debug_info, errors), where ir is a complete IR, ready for consumption
- by an Emboss back end, debug_info is a DebugInfo containing the
- tokenization, parse tree, and original source text of all modules, and
- errors is a list of tokenization or parse errors. If errors is not an empty
- list, ir will be None.
- """
- ir, debug_info, errors = only_parse_emboss_file(file_name, file_reader)
- if errors:
- return _IrDebugInfo(None, debug_info, errors)
- ir, errors = process_ir(ir, stop_before_step)
- if errors:
- return _IrDebugInfo(None, debug_info, errors)
- return _IrDebugInfo(ir, debug_info, errors)
+ Returns:
+ (ir, debug_info, errors), where ir is a complete IR, ready for consumption
+ by an Emboss back end, debug_info is a DebugInfo containing the
+ tokenization, parse tree, and original source text of all modules, and
+ errors is a list of tokenization or parse errors. If errors is not an empty
+ list, ir will be None.
+ """
+ ir, debug_info, errors = only_parse_emboss_file(file_name, file_reader)
+ if errors:
+ return _IrDebugInfo(None, debug_info, errors)
+ ir, errors = process_ir(ir, stop_before_step)
+ if errors:
+ return _IrDebugInfo(None, debug_info, errors)
+ return _IrDebugInfo(ir, debug_info, errors)
def only_parse_emboss_file(file_name, file_reader):
- """Parses an .emb, and returns an IR suitable for process_ir.
+ """Parses an .emb, and returns an IR suitable for process_ir.
- only_parse_emboss_file parses the given file and all of its transitive
- imports, and returns a first-stage intermediate representation, which can be
- passed to process_ir.
+ only_parse_emboss_file parses the given file and all of its transitive
+ imports, and returns a first-stage intermediate representation, which can be
+ passed to process_ir.
- Arguments:
- file_name: The name of the module's source file.
- file_reader: A callable that returns the contents of files, or raises
- IOError.
+ Arguments:
+ file_name: The name of the module's source file.
+ file_reader: A callable that returns the contents of files, or raises
+ IOError.
- Returns:
- (ir, debug_info, errors), where ir is an intermediate representation (IR),
- debug_info is a DebugInfo containing the tokenization, parse tree, and
- original source text of all modules, and errors is a list of tokenization or
- parse errors. If errors is not an empty list, ir will be None.
- """
- file_queue = [file_name]
- files = {file_name}
- debug_info = DebugInfo()
- ir = ir_data.EmbossIr(module=[])
- while file_queue:
- file_to_parse = file_queue[0]
- del file_queue[0]
- if file_to_parse:
- module, module_debug_info, errors = parse_module(file_to_parse,
- file_reader)
- else:
- module, module_debug_info, errors = get_prelude()
- if module_debug_info:
- debug_info.modules[file_to_parse] = module_debug_info
- if errors:
- return _IrDebugInfo(None, debug_info, errors)
- ir.module.extend([module]) # Proto supports extend but not append here.
- for import_ in module.foreign_import:
- if import_.file_name.text not in files:
- file_queue.append(import_.file_name.text)
- files.add(import_.file_name.text)
- return _IrDebugInfo(ir, debug_info, [])
+ Returns:
+ (ir, debug_info, errors), where ir is an intermediate representation (IR),
+ debug_info is a DebugInfo containing the tokenization, parse tree, and
+ original source text of all modules, and errors is a list of tokenization or
+ parse errors. If errors is not an empty list, ir will be None.
+ """
+ file_queue = [file_name]
+ files = {file_name}
+ debug_info = DebugInfo()
+ ir = ir_data.EmbossIr(module=[])
+ while file_queue:
+ file_to_parse = file_queue[0]
+ del file_queue[0]
+ if file_to_parse:
+ module, module_debug_info, errors = parse_module(file_to_parse, file_reader)
+ else:
+ module, module_debug_info, errors = get_prelude()
+ if module_debug_info:
+ debug_info.modules[file_to_parse] = module_debug_info
+ if errors:
+ return _IrDebugInfo(None, debug_info, errors)
+ ir.module.extend([module]) # Proto supports extend but not append here.
+ for import_ in module.foreign_import:
+ if import_.file_name.text not in files:
+ file_queue.append(import_.file_name.text)
+ files.add(import_.file_name.text)
+ return _IrDebugInfo(ir, debug_info, [])
def process_ir(ir, stop_before_step):
- """Turns a first-stage IR into a fully-processed IR.
+ """Turns a first-stage IR into a fully-processed IR.
- process_ir performs all of the semantic processing steps on `ir`: resolving
- symbols, checking dependencies, adding type annotations, normalizing
- attributes, etc. process_ir is generally meant to be called with the result
- of parse_emboss_file(), but in theory could be called with a first-stage
- intermediate representation (IR) from another source.
+ process_ir performs all of the semantic processing steps on `ir`: resolving
+ symbols, checking dependencies, adding type annotations, normalizing
+ attributes, etc. process_ir is generally meant to be called with the result
+ of parse_emboss_file(), but in theory could be called with a first-stage
+ intermediate representation (IR) from another source.
- Arguments:
- ir: The IR to process. This structure will be modified during processing.
- stop_before_step: If set, process_ir will stop normalizing the IR just
- before the specified step. This parameter should be None for non-test
- code.
+ Arguments:
+ ir: The IR to process. This structure will be modified during processing.
+ stop_before_step: If set, process_ir will stop normalizing the IR just
+ before the specified step. This parameter should be None for non-test
+ code.
- Returns:
- (ir, errors), where ir is a complete IR, ready for consumption by an Emboss
- back end, and errors is a list of compilation errors. If errors is not an
- empty list, ir will be None.
- """
- passes = (synthetics.desugar,
- symbol_resolver.resolve_symbols,
- dependency_checker.find_dependency_cycles,
- dependency_checker.set_dependency_order,
- symbol_resolver.resolve_field_references,
- type_check.annotate_types,
- type_check.check_types,
- expression_bounds.compute_constants,
- attribute_checker.normalize_and_verify,
- constraints.check_constraints,
- write_inference.set_write_methods)
- assert stop_before_step in [None] + [f.__name__ for f in passes], (
- "Bad value for stop_before_step.")
- # Some parts of the IR are synthesized from "natural" parts of the IR, before
- # the natural parts have been fully error checked. Because of this, the
- # synthesized parts can have errors; in a couple of cases, they can have
- # errors that show up in an earlier pass than the errors in the natural parts
- # of the IR. As an example:
- #
- # struct Foo:
- # 0 [+1] bits:
- # 0 [+1] Flag flag
- # 1 [+flag] UInt:8 field
- #
- # In this case, the use of `flag` as the size of `field` is incorrect, because
- # `flag` is a boolean, but the size of a field must be an integer.
- #
- # Type checking occurs in two passes: in the first pass, expressions are
- # checked for internal consistency. In the second pass, expression types are
- # checked against their location. The use of `flag` would be caught in the
- # second pass.
- #
- # However, the generated_fields pass will synthesize a $size_in_bytes virtual
- # field that would look like:
- #
- # struct Foo:
- # 0 [+1] bits:
- # 0 [+1] Flag flag
- # 1 [+flag] UInt:8 field
- # let $size_in_bytes = $max(true ? 0 + 1 : 0, true ? 1 + flag : 0)
- #
- # Since `1 + flag` is not internally consistent, this type error would be
- # caught in the first pass, and the user would see a very strange error
- # message that "the right-hand argument of operator `+` must be an integer."
- #
- # In order to avoid showing these kinds of errors to the user, we defer any
- # errors in synthetic parts of the IR. Unless there is a compiler bug, those
- # errors will show up as errors in the natural parts of the IR, which should
- # be much more comprehensible to end users.
- #
- # If, for some reason, there is an error in the synthetic IR, but no error in
- # the natural IR, the synthetic errors will be shown. In this case, the
- # formatting for the synthetic errors will show '[compiler bug]' for the
- # error location, which (hopefully) will provide the end user with a cue that
- # the error is a compiler bug.
- deferred_errors = []
- for function in passes:
- if stop_before_step == function.__name__:
- return (ir, [])
- errors, hidden_errors = error.split_errors(function(ir))
- if errors:
- return (None, errors)
- deferred_errors.extend(hidden_errors)
+ Returns:
+ (ir, errors), where ir is a complete IR, ready for consumption by an Emboss
+ back end, and errors is a list of compilation errors. If errors is not an
+ empty list, ir will be None.
+ """
+ passes = (
+ synthetics.desugar,
+ symbol_resolver.resolve_symbols,
+ dependency_checker.find_dependency_cycles,
+ dependency_checker.set_dependency_order,
+ symbol_resolver.resolve_field_references,
+ type_check.annotate_types,
+ type_check.check_types,
+ expression_bounds.compute_constants,
+ attribute_checker.normalize_and_verify,
+ constraints.check_constraints,
+ write_inference.set_write_methods,
+ )
+ assert stop_before_step in [None] + [
+ f.__name__ for f in passes
+ ], "Bad value for stop_before_step."
+ # Some parts of the IR are synthesized from "natural" parts of the IR, before
+ # the natural parts have been fully error checked. Because of this, the
+ # synthesized parts can have errors; in a couple of cases, they can have
+ # errors that show up in an earlier pass than the errors in the natural parts
+ # of the IR. As an example:
+ #
+ # struct Foo:
+ # 0 [+1] bits:
+ # 0 [+1] Flag flag
+ # 1 [+flag] UInt:8 field
+ #
+ # In this case, the use of `flag` as the size of `field` is incorrect, because
+ # `flag` is a boolean, but the size of a field must be an integer.
+ #
+ # Type checking occurs in two passes: in the first pass, expressions are
+ # checked for internal consistency. In the second pass, expression types are
+ # checked against their location. The use of `flag` would be caught in the
+ # second pass.
+ #
+ # However, the generated_fields pass will synthesize a $size_in_bytes virtual
+ # field that would look like:
+ #
+ # struct Foo:
+ # 0 [+1] bits:
+ # 0 [+1] Flag flag
+ # 1 [+flag] UInt:8 field
+ # let $size_in_bytes = $max(true ? 0 + 1 : 0, true ? 1 + flag : 0)
+ #
+ # Since `1 + flag` is not internally consistent, this type error would be
+ # caught in the first pass, and the user would see a very strange error
+ # message that "the right-hand argument of operator `+` must be an integer."
+ #
+ # In order to avoid showing these kinds of errors to the user, we defer any
+ # errors in synthetic parts of the IR. Unless there is a compiler bug, those
+ # errors will show up as errors in the natural parts of the IR, which should
+ # be much more comprehensible to end users.
+ #
+ # If, for some reason, there is an error in the synthetic IR, but no error in
+ # the natural IR, the synthetic errors will be shown. In this case, the
+ # formatting for the synthetic errors will show '[compiler bug]' for the
+ # error location, which (hopefully) will provide the end user with a cue that
+ # the error is a compiler bug.
+ deferred_errors = []
+ for function in passes:
+ if stop_before_step == function.__name__:
+ return (ir, [])
+ errors, hidden_errors = error.split_errors(function(ir))
+ if errors:
+ return (None, errors)
+ deferred_errors.extend(hidden_errors)
- if deferred_errors:
- return (None, deferred_errors)
+ if deferred_errors:
+ return (None, deferred_errors)
- assert stop_before_step is None, "Bad value for stop_before_step."
- return (ir, [])
+ assert stop_before_step is None, "Bad value for stop_before_step."
+ return (ir, [])
diff --git a/compiler/front_end/glue_test.py b/compiler/front_end/glue_test.py
index 2f2ddc5..1decae3 100644
--- a/compiler/front_end/glue_test.py
+++ b/compiler/front_end/glue_test.py
@@ -30,272 +30,328 @@
_GOLDEN_PATH = ""
_SPAN_SE_LOG_FILE_PATH = _GOLDEN_PATH + "span_se_log_file_status.emb"
-_SPAN_SE_LOG_FILE_EMB = pkgutil.get_data(
- _ROOT_PACKAGE, _SPAN_SE_LOG_FILE_PATH).decode(encoding="UTF-8")
+_SPAN_SE_LOG_FILE_EMB = pkgutil.get_data(_ROOT_PACKAGE, _SPAN_SE_LOG_FILE_PATH).decode(
+ encoding="UTF-8"
+)
_SPAN_SE_LOG_FILE_READER = test_util.dict_file_reader(
- {_SPAN_SE_LOG_FILE_PATH: _SPAN_SE_LOG_FILE_EMB})
-_SPAN_SE_LOG_FILE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.Module,
+ {_SPAN_SE_LOG_FILE_PATH: _SPAN_SE_LOG_FILE_EMB}
+)
+_SPAN_SE_LOG_FILE_IR = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Module,
pkgutil.get_data(
- _ROOT_PACKAGE,
- _GOLDEN_PATH + "span_se_log_file_status.ir.txt"
- ).decode(encoding="UTF-8"))
+ _ROOT_PACKAGE, _GOLDEN_PATH + "span_se_log_file_status.ir.txt"
+ ).decode(encoding="UTF-8"),
+)
_SPAN_SE_LOG_FILE_PARSE_TREE_TEXT = pkgutil.get_data(
- _ROOT_PACKAGE,
- _GOLDEN_PATH + "span_se_log_file_status.parse_tree.txt"
+ _ROOT_PACKAGE, _GOLDEN_PATH + "span_se_log_file_status.parse_tree.txt"
).decode(encoding="UTF-8")
_SPAN_SE_LOG_FILE_TOKENIZATION_TEXT = pkgutil.get_data(
- _ROOT_PACKAGE,
- _GOLDEN_PATH + "span_se_log_file_status.tokens.txt"
+ _ROOT_PACKAGE, _GOLDEN_PATH + "span_se_log_file_status.tokens.txt"
).decode(encoding="UTF-8")
class FrontEndGlueTest(unittest.TestCase):
- """Tests for front_end.glue."""
+ """Tests for front_end.glue."""
- def test_parse_module(self):
- # parse_module(file) should return the same thing as
- # parse_module_text(text), assuming file can be read.
- main_module, debug_info, errors = glue.parse_module(
- _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER)
- main_module2, debug_info2, errors2 = glue.parse_module_text(
- _SPAN_SE_LOG_FILE_EMB, _SPAN_SE_LOG_FILE_PATH)
- self.assertEqual([], errors)
- self.assertEqual([], errors2)
- self.assertEqual(main_module, main_module2)
- self.assertEqual(debug_info, debug_info2)
+ def test_parse_module(self):
+ # parse_module(file) should return the same thing as
+ # parse_module_text(text), assuming file can be read.
+ main_module, debug_info, errors = glue.parse_module(
+ _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER
+ )
+ main_module2, debug_info2, errors2 = glue.parse_module_text(
+ _SPAN_SE_LOG_FILE_EMB, _SPAN_SE_LOG_FILE_PATH
+ )
+ self.assertEqual([], errors)
+ self.assertEqual([], errors2)
+ self.assertEqual(main_module, main_module2)
+ self.assertEqual(debug_info, debug_info2)
- def test_parse_module_no_such_file(self):
- file_name = "nonexistent.emb"
- ir, debug_info, errors = glue.parse_emboss_file(
- file_name, test_util.dict_file_reader({}))
- self.assertEqual([[
- error.error("nonexistent.emb", _location((1, 1), (1, 1)),
- "Unable to read file."),
- error.note("nonexistent.emb", _location((1, 1), (1, 1)),
- "File 'nonexistent.emb' not found."),
- ]], errors)
- self.assertFalse(file_name in debug_info.modules)
- self.assertFalse(ir)
+ def test_parse_module_no_such_file(self):
+ file_name = "nonexistent.emb"
+ ir, debug_info, errors = glue.parse_emboss_file(
+ file_name, test_util.dict_file_reader({})
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "nonexistent.emb",
+ _location((1, 1), (1, 1)),
+ "Unable to read file.",
+ ),
+ error.note(
+ "nonexistent.emb",
+ _location((1, 1), (1, 1)),
+ "File 'nonexistent.emb' not found.",
+ ),
+ ]
+ ],
+ errors,
+ )
+ self.assertFalse(file_name in debug_info.modules)
+ self.assertFalse(ir)
- def test_parse_module_tokenization_error(self):
- file_name = "tokens.emb"
- ir, debug_info, errors = glue.parse_emboss_file(
- file_name, test_util.dict_file_reader({file_name: "@"}))
- self.assertTrue(debug_info.modules[file_name].source_code)
- self.assertTrue(errors)
- self.assertEqual("Unrecognized token", errors[0][0].message)
- self.assertFalse(ir)
+ def test_parse_module_tokenization_error(self):
+ file_name = "tokens.emb"
+ ir, debug_info, errors = glue.parse_emboss_file(
+ file_name, test_util.dict_file_reader({file_name: "@"})
+ )
+ self.assertTrue(debug_info.modules[file_name].source_code)
+ self.assertTrue(errors)
+ self.assertEqual("Unrecognized token", errors[0][0].message)
+ self.assertFalse(ir)
- def test_parse_module_indentation_error(self):
- file_name = "indent.emb"
- ir, debug_info, errors = glue.parse_emboss_file(
- file_name, test_util.dict_file_reader(
- {file_name: "struct Foo:\n"
- " 1 [+1] Int x\n"
- " 2 [+1] Int y\n"}))
- self.assertTrue(debug_info.modules[file_name].source_code)
- self.assertTrue(errors)
- self.assertEqual("Bad indentation", errors[0][0].message)
- self.assertFalse(ir)
+ def test_parse_module_indentation_error(self):
+ file_name = "indent.emb"
+ ir, debug_info, errors = glue.parse_emboss_file(
+ file_name,
+ test_util.dict_file_reader(
+ {file_name: "struct Foo:\n" " 1 [+1] Int x\n" " 2 [+1] Int y\n"}
+ ),
+ )
+ self.assertTrue(debug_info.modules[file_name].source_code)
+ self.assertTrue(errors)
+ self.assertEqual("Bad indentation", errors[0][0].message)
+ self.assertFalse(ir)
- def test_parse_module_parse_error(self):
- file_name = "parse.emb"
- ir, debug_info, errors = glue.parse_emboss_file(
- file_name, test_util.dict_file_reader(
- {file_name: "struct foo:\n"
- " 1 [+1] Int x\n"
- " 3 [+1] Int y\n"}))
- self.assertTrue(debug_info.modules[file_name].source_code)
- self.assertEqual([[
- error.error(file_name, _location((1, 8), (1, 11)),
- "A type name must be CamelCase.\n"
- "Found 'foo' (SnakeWord), expected CamelWord.")
- ]], errors)
- self.assertFalse(ir)
+ def test_parse_module_parse_error(self):
+ file_name = "parse.emb"
+ ir, debug_info, errors = glue.parse_emboss_file(
+ file_name,
+ test_util.dict_file_reader(
+ {file_name: "struct foo:\n" " 1 [+1] Int x\n" " 3 [+1] Int y\n"}
+ ),
+ )
+ self.assertTrue(debug_info.modules[file_name].source_code)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ file_name,
+ _location((1, 8), (1, 11)),
+ "A type name must be CamelCase.\n"
+ "Found 'foo' (SnakeWord), expected CamelWord.",
+ )
+ ]
+ ],
+ errors,
+ )
+ self.assertFalse(ir)
- def test_parse_error(self):
- file_name = "parse.emb"
- ir, debug_info, errors = glue.parse_emboss_file(
- file_name, test_util.dict_file_reader(
- {file_name: "struct foo:\n"
- " 1 [+1] Int x\n"
- " 2 [+1] Int y\n"}))
- self.assertTrue(debug_info.modules[file_name].source_code)
- self.assertEqual([[
- error.error(file_name, _location((1, 8), (1, 11)),
- "A type name must be CamelCase.\n"
- "Found 'foo' (SnakeWord), expected CamelWord.")
- ]], errors)
- self.assertFalse(ir)
+ def test_parse_error(self):
+ file_name = "parse.emb"
+ ir, debug_info, errors = glue.parse_emboss_file(
+ file_name,
+ test_util.dict_file_reader(
+ {file_name: "struct foo:\n" " 1 [+1] Int x\n" " 2 [+1] Int y\n"}
+ ),
+ )
+ self.assertTrue(debug_info.modules[file_name].source_code)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ file_name,
+ _location((1, 8), (1, 11)),
+ "A type name must be CamelCase.\n"
+ "Found 'foo' (SnakeWord), expected CamelWord.",
+ )
+ ]
+ ],
+ errors,
+ )
+ self.assertFalse(ir)
- def test_circular_dependency_error(self):
- file_name = "cycle.emb"
- ir, debug_info, errors = glue.parse_emboss_file(
- file_name, test_util.dict_file_reader({
- file_name: "struct Foo:\n"
- " 0 [+field1] UInt field1\n"
- }))
- self.assertTrue(debug_info.modules[file_name].source_code)
- self.assertTrue(errors)
- self.assertEqual("Dependency cycle\nfield1", errors[0][0].message)
- self.assertFalse(ir)
+ def test_circular_dependency_error(self):
+ file_name = "cycle.emb"
+ ir, debug_info, errors = glue.parse_emboss_file(
+ file_name,
+ test_util.dict_file_reader(
+ {file_name: "struct Foo:\n" " 0 [+field1] UInt field1\n"}
+ ),
+ )
+ self.assertTrue(debug_info.modules[file_name].source_code)
+ self.assertTrue(errors)
+ self.assertEqual("Dependency cycle\nfield1", errors[0][0].message)
+ self.assertFalse(ir)
- def test_ir_from_parse_module(self):
- log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR)
- log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH
- self.assertEqual(log_file_path_ir, glue.parse_module(
- _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir)
+ def test_ir_from_parse_module(self):
+ log_file_path_ir = ir_data_utils.copy(_SPAN_SE_LOG_FILE_IR)
+ log_file_path_ir.source_file_name = _SPAN_SE_LOG_FILE_PATH
+ self.assertEqual(
+ log_file_path_ir,
+ glue.parse_module(_SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER).ir,
+ )
- def test_debug_info_from_parse_module(self):
- debug_info = glue.parse_module(_SPAN_SE_LOG_FILE_PATH,
- _SPAN_SE_LOG_FILE_READER).debug_info
- self.maxDiff = 200000 # pylint:disable=invalid-name
- self.assertEqual(_SPAN_SE_LOG_FILE_TOKENIZATION_TEXT.strip(),
- debug_info.format_tokenization().strip())
- self.assertEqual(_SPAN_SE_LOG_FILE_PARSE_TREE_TEXT.strip(),
- debug_info.format_parse_tree().strip())
- self.assertEqual(_SPAN_SE_LOG_FILE_IR, debug_info.ir)
- self.assertEqual(ir_data_utils.IrDataSerializer(_SPAN_SE_LOG_FILE_IR).to_json(indent=2),
- debug_info.format_module_ir())
+ def test_debug_info_from_parse_module(self):
+ debug_info = glue.parse_module(
+ _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER
+ ).debug_info
+ self.maxDiff = 200000 # pylint:disable=invalid-name
+ self.assertEqual(
+ _SPAN_SE_LOG_FILE_TOKENIZATION_TEXT.strip(),
+ debug_info.format_tokenization().strip(),
+ )
+ self.assertEqual(
+ _SPAN_SE_LOG_FILE_PARSE_TREE_TEXT.strip(),
+ debug_info.format_parse_tree().strip(),
+ )
+ self.assertEqual(_SPAN_SE_LOG_FILE_IR, debug_info.ir)
+ self.assertEqual(
+ ir_data_utils.IrDataSerializer(_SPAN_SE_LOG_FILE_IR).to_json(indent=2),
+ debug_info.format_module_ir(),
+ )
- def test_parse_emboss_file(self):
- # parse_emboss_file calls parse_module, wraps its results, and calls
- # symbol_resolver.resolve_symbols() on the resulting IR.
- ir, debug_info, errors = glue.parse_emboss_file(_SPAN_SE_LOG_FILE_PATH,
- _SPAN_SE_LOG_FILE_READER)
- module_ir, module_debug_info, module_errors = glue.parse_module(
- _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER)
- self.assertEqual([], errors)
- self.assertEqual([], module_errors)
- self.assertTrue(test_util.proto_is_superset(ir.module[0], module_ir))
- self.assertEqual(module_debug_info,
- debug_info.modules[_SPAN_SE_LOG_FILE_PATH])
- self.assertEqual(2, len(debug_info.modules))
- self.assertEqual(2, len(ir.module))
- self.assertEqual(_SPAN_SE_LOG_FILE_PATH, ir.module[0].source_file_name)
- self.assertEqual("", ir.module[1].source_file_name)
+ def test_parse_emboss_file(self):
+ # parse_emboss_file calls parse_module, wraps its results, and calls
+ # symbol_resolver.resolve_symbols() on the resulting IR.
+ ir, debug_info, errors = glue.parse_emboss_file(
+ _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER
+ )
+ module_ir, module_debug_info, module_errors = glue.parse_module(
+ _SPAN_SE_LOG_FILE_PATH, _SPAN_SE_LOG_FILE_READER
+ )
+ self.assertEqual([], errors)
+ self.assertEqual([], module_errors)
+ self.assertTrue(test_util.proto_is_superset(ir.module[0], module_ir))
+ self.assertEqual(module_debug_info, debug_info.modules[_SPAN_SE_LOG_FILE_PATH])
+ self.assertEqual(2, len(debug_info.modules))
+ self.assertEqual(2, len(ir.module))
+ self.assertEqual(_SPAN_SE_LOG_FILE_PATH, ir.module[0].source_file_name)
+ self.assertEqual("", ir.module[1].source_file_name)
- def test_synthetic_error(self):
- file_name = "missing_byte_order_attribute.emb"
- ir, unused_debug_info, errors = glue.only_parse_emboss_file(
- file_name, test_util.dict_file_reader({
- file_name: "struct Foo:\n"
- " 0 [+8] UInt field\n"
- }))
- self.assertFalse(errors)
- # Artificially mark the first field as is_synthetic.
- first_field = ir.module[0].type[0].structure.field[0]
- first_field.source_location.is_synthetic = True
- ir, errors = glue.process_ir(ir, None)
- self.assertTrue(errors)
- self.assertEqual("Attribute 'byte_order' required on field which is byte "
- "order dependent.", errors[0][0].message)
- self.assertTrue(errors[0][0].location.is_synthetic)
- self.assertFalse(ir)
+ def test_synthetic_error(self):
+ file_name = "missing_byte_order_attribute.emb"
+ ir, unused_debug_info, errors = glue.only_parse_emboss_file(
+ file_name,
+ test_util.dict_file_reader(
+ {file_name: "struct Foo:\n" " 0 [+8] UInt field\n"}
+ ),
+ )
+ self.assertFalse(errors)
+ # Artificially mark the first field as is_synthetic.
+ first_field = ir.module[0].type[0].structure.field[0]
+ first_field.source_location.is_synthetic = True
+ ir, errors = glue.process_ir(ir, None)
+ self.assertTrue(errors)
+ self.assertEqual(
+ "Attribute 'byte_order' required on field which is byte "
+ "order dependent.",
+ errors[0][0].message,
+ )
+ self.assertTrue(errors[0][0].location.is_synthetic)
+ self.assertFalse(ir)
- def test_suppressed_synthetic_error(self):
- file_name = "triplicate_symbol.emb"
- ir, unused_debug_info, errors = glue.only_parse_emboss_file(
- file_name, test_util.dict_file_reader({
- file_name: "struct Foo:\n"
- " 0 [+1] UInt field\n"
- " 1 [+1] UInt field\n"
- " 2 [+1] UInt field\n"
- }))
- self.assertFalse(errors)
- # Artificially mark the name of the second field as is_synthetic.
- second_field = ir.module[0].type[0].structure.field[1]
- second_field.name.source_location.is_synthetic = True
- second_field.name.name.source_location.is_synthetic = True
- ir, errors = glue.process_ir(ir, None)
- self.assertEqual(1, len(errors))
- self.assertEqual("Duplicate name 'field'", errors[0][0].message)
- self.assertFalse(errors[0][0].location.is_synthetic)
- self.assertFalse(errors[0][1].location.is_synthetic)
- self.assertFalse(ir)
+ def test_suppressed_synthetic_error(self):
+ file_name = "triplicate_symbol.emb"
+ ir, unused_debug_info, errors = glue.only_parse_emboss_file(
+ file_name,
+ test_util.dict_file_reader(
+ {
+ file_name: "struct Foo:\n"
+ " 0 [+1] UInt field\n"
+ " 1 [+1] UInt field\n"
+ " 2 [+1] UInt field\n"
+ }
+ ),
+ )
+ self.assertFalse(errors)
+ # Artificially mark the name of the second field as is_synthetic.
+ second_field = ir.module[0].type[0].structure.field[1]
+ second_field.name.source_location.is_synthetic = True
+ second_field.name.name.source_location.is_synthetic = True
+ ir, errors = glue.process_ir(ir, None)
+ self.assertEqual(1, len(errors))
+ self.assertEqual("Duplicate name 'field'", errors[0][0].message)
+ self.assertFalse(errors[0][0].location.is_synthetic)
+ self.assertFalse(errors[0][1].location.is_synthetic)
+ self.assertFalse(ir)
class DebugInfoTest(unittest.TestCase):
- """Tests for DebugInfo and ModuleDebugInfo classes."""
+ """Tests for DebugInfo and ModuleDebugInfo classes."""
- def test_debug_info_initialization(self):
- debug_info = glue.DebugInfo()
- self.assertEqual({}, debug_info.modules)
+ def test_debug_info_initialization(self):
+ debug_info = glue.DebugInfo()
+ self.assertEqual({}, debug_info.modules)
- def test_debug_info_invalid_attribute_set(self):
- debug_info = glue.DebugInfo()
- with self.assertRaises(AttributeError):
- debug_info.foo = "foo"
+ def test_debug_info_invalid_attribute_set(self):
+ debug_info = glue.DebugInfo()
+ with self.assertRaises(AttributeError):
+ debug_info.foo = "foo"
- def test_debug_info_equality(self):
- debug_info = glue.DebugInfo()
- debug_info2 = glue.DebugInfo()
- self.assertEqual(debug_info, debug_info2)
- debug_info.modules["foo"] = glue.ModuleDebugInfo("foo")
- self.assertNotEqual(debug_info, debug_info2)
- debug_info2.modules["foo"] = glue.ModuleDebugInfo("foo")
- self.assertEqual(debug_info, debug_info2)
+ def test_debug_info_equality(self):
+ debug_info = glue.DebugInfo()
+ debug_info2 = glue.DebugInfo()
+ self.assertEqual(debug_info, debug_info2)
+ debug_info.modules["foo"] = glue.ModuleDebugInfo("foo")
+ self.assertNotEqual(debug_info, debug_info2)
+ debug_info2.modules["foo"] = glue.ModuleDebugInfo("foo")
+ self.assertEqual(debug_info, debug_info2)
- def test_module_debug_info_initialization(self):
- module_info = glue.ModuleDebugInfo("bar.emb")
- self.assertEqual("bar.emb", module_info.file_name)
- self.assertEqual(None, module_info.tokens)
- self.assertEqual(None, module_info.parse_tree)
- self.assertEqual(None, module_info.ir)
- self.assertEqual(None, module_info.used_productions)
+ def test_module_debug_info_initialization(self):
+ module_info = glue.ModuleDebugInfo("bar.emb")
+ self.assertEqual("bar.emb", module_info.file_name)
+ self.assertEqual(None, module_info.tokens)
+ self.assertEqual(None, module_info.parse_tree)
+ self.assertEqual(None, module_info.ir)
+ self.assertEqual(None, module_info.used_productions)
- def test_module_debug_info_attribute_set(self):
- module_info = glue.ModuleDebugInfo("bar.emb")
- module_info.tokens = "a"
- module_info.parse_tree = "b"
- module_info.ir = "c"
- module_info.used_productions = "d"
- module_info.source_code = "e"
- self.assertEqual("a", module_info.tokens)
- self.assertEqual("b", module_info.parse_tree)
- self.assertEqual("c", module_info.ir)
- self.assertEqual("d", module_info.used_productions)
- self.assertEqual("e", module_info.source_code)
+ def test_module_debug_info_attribute_set(self):
+ module_info = glue.ModuleDebugInfo("bar.emb")
+ module_info.tokens = "a"
+ module_info.parse_tree = "b"
+ module_info.ir = "c"
+ module_info.used_productions = "d"
+ module_info.source_code = "e"
+ self.assertEqual("a", module_info.tokens)
+ self.assertEqual("b", module_info.parse_tree)
+ self.assertEqual("c", module_info.ir)
+ self.assertEqual("d", module_info.used_productions)
+ self.assertEqual("e", module_info.source_code)
- def test_module_debug_info_bad_attribute_set(self):
- module_info = glue.ModuleDebugInfo("bar.emb")
- with self.assertRaises(AttributeError):
- module_info.foo = "foo"
+ def test_module_debug_info_bad_attribute_set(self):
+ module_info = glue.ModuleDebugInfo("bar.emb")
+ with self.assertRaises(AttributeError):
+ module_info.foo = "foo"
- def test_module_debug_info_equality(self):
- module_info = glue.ModuleDebugInfo("foo")
- module_info2 = glue.ModuleDebugInfo("foo")
- module_info_bar = glue.ModuleDebugInfo("bar")
- self.assertEqual(module_info, module_info2)
- module_info_bar = glue.ModuleDebugInfo("bar")
- self.assertNotEqual(module_info, module_info_bar)
- module_info.tokens = []
- self.assertNotEqual(module_info, module_info2)
- module_info2.tokens = []
- self.assertEqual(module_info, module_info2)
- module_info.parse_tree = []
- self.assertNotEqual(module_info, module_info2)
- module_info2.parse_tree = []
- self.assertEqual(module_info, module_info2)
- module_info.ir = []
- self.assertNotEqual(module_info, module_info2)
- module_info2.ir = []
- self.assertEqual(module_info, module_info2)
- module_info.used_productions = []
- self.assertNotEqual(module_info, module_info2)
- module_info2.used_productions = []
- self.assertEqual(module_info, module_info2)
+ def test_module_debug_info_equality(self):
+ module_info = glue.ModuleDebugInfo("foo")
+ module_info2 = glue.ModuleDebugInfo("foo")
+ module_info_bar = glue.ModuleDebugInfo("bar")
+ self.assertEqual(module_info, module_info2)
+ module_info_bar = glue.ModuleDebugInfo("bar")
+ self.assertNotEqual(module_info, module_info_bar)
+ module_info.tokens = []
+ self.assertNotEqual(module_info, module_info2)
+ module_info2.tokens = []
+ self.assertEqual(module_info, module_info2)
+ module_info.parse_tree = []
+ self.assertNotEqual(module_info, module_info2)
+ module_info2.parse_tree = []
+ self.assertEqual(module_info, module_info2)
+ module_info.ir = []
+ self.assertNotEqual(module_info, module_info2)
+ module_info2.ir = []
+ self.assertEqual(module_info, module_info2)
+ module_info.used_productions = []
+ self.assertNotEqual(module_info, module_info2)
+ module_info2.used_productions = []
+ self.assertEqual(module_info, module_info2)
class TestFormatProductionSet(unittest.TestCase):
- """Tests for format_production_set."""
+ """Tests for format_production_set."""
- def test_format_production_set(self):
- production_texts = ["A -> B", "B -> C", "A -> C", "C -> A"]
- productions = [parser_types.Production.parse(p) for p in production_texts]
- self.assertEqual("\n".join(sorted(production_texts)),
- glue.format_production_set(set(productions)))
+ def test_format_production_set(self):
+ production_texts = ["A -> B", "B -> C", "A -> C", "C -> A"]
+ productions = [parser_types.Production.parse(p) for p in production_texts]
+ self.assertEqual(
+ "\n".join(sorted(production_texts)),
+ glue.format_production_set(set(productions)),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/lr1.py b/compiler/front_end/lr1.py
index 41111d1..be99f95 100644
--- a/compiler/front_end/lr1.py
+++ b/compiler/front_end/lr1.py
@@ -31,81 +31,97 @@
from compiler.util import parser_types
-class Item(collections.namedtuple("Item", ["production", "dot", "terminal",
- "next_symbol"])):
- """An Item is an LR(1) Item: a production, a cursor location, and a terminal.
+class Item(
+ collections.namedtuple("Item", ["production", "dot", "terminal", "next_symbol"])
+):
+ """An Item is an LR(1) Item: a production, a cursor location, and a terminal.
- An Item represents a partially-parsed production, and a lookahead symbol. The
- position of the dot indicates what portion of the production has been parsed.
- Generally, Items are an internal implementation detail, but they can be useful
- elsewhere, particularly for debugging.
+ An Item represents a partially-parsed production, and a lookahead symbol. The
+ position of the dot indicates what portion of the production has been parsed.
+ Generally, Items are an internal implementation detail, but they can be useful
+ elsewhere, particularly for debugging.
- Attributes:
- production: The Production this Item covers.
- dot: The index of the "dot" in production's rhs.
- terminal: The terminal lookahead symbol that follows the production in the
- input stream.
- """
-
- def __str__(self):
- """__str__ generates ASLU notation."""
- return (str(self.production.lhs) + " -> " + " ".join(
- [str(r) for r in self.production.rhs[0:self.dot] + (".",) +
- self.production.rhs[self.dot:]]) + ", " + str(self.terminal))
-
- @staticmethod
- def parse(text):
- """Parses an Item in ALSU notation.
-
- Parses an Item from notation like:
-
- symbol -> foo . bar baz, qux
-
- where "symbol -> foo bar baz" will be taken as the production, the position
- of the "." is taken as "dot" (in this case 1), and the symbol after "," is
- taken as the "terminal". The following are also valid items:
-
- sym -> ., foo
- sym -> . foo bar, baz
- sym -> foo bar ., baz
-
- Symbols on the right-hand side of the production should be separated by
- whitespace.
-
- Arguments:
- text: The text to parse into an Item.
-
- Returns:
- An Item.
+ Attributes:
+ production: The Production this Item covers.
+ dot: The index of the "dot" in production's rhs.
+ terminal: The terminal lookahead symbol that follows the production in the
+ input stream.
"""
- production, terminal = text.split(",")
- terminal = terminal.strip()
- if terminal == "$":
- terminal = END_OF_INPUT
- lhs, rhs = production.split("->")
- lhs = lhs.strip()
- if lhs == "S'":
- lhs = START_PRIME
- before_dot, after_dot = rhs.split(".")
- handle = before_dot.split()
- tail = after_dot.split()
- return make_item(parser_types.Production(lhs, tuple(handle + tail)),
- len(handle), terminal)
+
+ def __str__(self):
+ """__str__ generates ASLU notation."""
+ return (
+ str(self.production.lhs)
+ + " -> "
+ + " ".join(
+ [
+ str(r)
+ for r in self.production.rhs[0 : self.dot]
+ + (".",)
+ + self.production.rhs[self.dot :]
+ ]
+ )
+ + ", "
+ + str(self.terminal)
+ )
+
+ @staticmethod
+ def parse(text):
+ """Parses an Item in ALSU notation.
+
+ Parses an Item from notation like:
+
+ symbol -> foo . bar baz, qux
+
+ where "symbol -> foo bar baz" will be taken as the production, the position
+ of the "." is taken as "dot" (in this case 1), and the symbol after "," is
+ taken as the "terminal". The following are also valid items:
+
+ sym -> ., foo
+ sym -> . foo bar, baz
+ sym -> foo bar ., baz
+
+ Symbols on the right-hand side of the production should be separated by
+ whitespace.
+
+ Arguments:
+ text: The text to parse into an Item.
+
+ Returns:
+ An Item.
+ """
+ production, terminal = text.split(",")
+ terminal = terminal.strip()
+ if terminal == "$":
+ terminal = END_OF_INPUT
+ lhs, rhs = production.split("->")
+ lhs = lhs.strip()
+ if lhs == "S'":
+ lhs = START_PRIME
+ before_dot, after_dot = rhs.split(".")
+ handle = before_dot.split()
+ tail = after_dot.split()
+ return make_item(
+ parser_types.Production(lhs, tuple(handle + tail)), len(handle), terminal
+ )
def make_item(production, dot, symbol):
- return Item(production, dot, symbol,
- None if dot >= len(production.rhs) else production.rhs[dot])
+ return Item(
+ production,
+ dot,
+ symbol,
+ None if dot >= len(production.rhs) else production.rhs[dot],
+ )
-class Conflict(
- collections.namedtuple("Conflict", ["state", "symbol", "actions"])
-):
- """Conflict represents a parse conflict."""
+class Conflict(collections.namedtuple("Conflict", ["state", "symbol", "actions"])):
+ """Conflict represents a parse conflict."""
- def __str__(self):
- return "Conflict for {} in state {}: ".format(
- self.symbol, self.state) + " vs ".join([str(a) for a in self.actions])
+ def __str__(self):
+ return "Conflict for {} in state {}: ".format(
+ self.symbol, self.state
+ ) + " vs ".join([str(a) for a in self.actions])
Shift = collections.namedtuple("Shift", ["state", "items"])
@@ -123,637 +139,695 @@
# ANY_TOKEN is used by mark_error as a "wildcard" token that should be replaced
# by every other token.
-ANY_TOKEN = parser_types.Token(object(), "*",
- parser_types.parse_location("0:0-0:0"))
+ANY_TOKEN = parser_types.Token(object(), "*", parser_types.parse_location("0:0-0:0"))
-class Reduction(collections.namedtuple("Reduction",
- ["symbol", "children", "production",
- "source_location"])):
- """A Reduction is a non-leaf node in a parse tree.
+class Reduction(
+ collections.namedtuple(
+ "Reduction", ["symbol", "children", "production", "source_location"]
+ )
+):
+ """A Reduction is a non-leaf node in a parse tree.
- Attributes:
- symbol: The name of this element in the parse.
- children: The child elements of this parse.
- production: The grammar production to which this reduction corresponds.
- source_location: If known, the range in the source text corresponding to the
- tokens from which this reduction was parsed. May be 'None' if this
- reduction was produced from no symbols, or if the tokens fed to `parse`
- did not include source_location.
- """
- pass
+ Attributes:
+ symbol: The name of this element in the parse.
+ children: The child elements of this parse.
+ production: The grammar production to which this reduction corresponds.
+ source_location: If known, the range in the source text corresponding to the
+ tokens from which this reduction was parsed. May be 'None' if this
+ reduction was produced from no symbols, or if the tokens fed to `parse`
+ did not include source_location.
+ """
+
+ pass
class Grammar(object):
- """Grammar is an LR(1) context-free grammar.
+ """Grammar is an LR(1) context-free grammar.
- Attributes:
- start: The start symbol for the grammar.
- productions: A list of productions in the grammar, including the S' -> start
- production.
- symbols: A set of all symbols in the grammar, including $ and S'.
- nonterminals: A set of all nonterminal symbols in the grammar, including S'.
- terminals: A set of all terminal symbols in the grammar, including $.
- """
-
- def __init__(self, start_symbol, productions):
- """Constructs a Grammar object.
-
- Arguments:
- start_symbol: The start symbol for the grammar.
- productions: A list of productions (not including the "S' -> start_symbol"
- production).
+ Attributes:
+ start: The start symbol for the grammar.
+ productions: A list of productions in the grammar, including the S' -> start
+ production.
+ symbols: A set of all symbols in the grammar, including $ and S'.
+ nonterminals: A set of all nonterminal symbols in the grammar, including S'.
+ terminals: A set of all terminal symbols in the grammar, including $.
"""
- object.__init__(self)
- self.start = start_symbol
- self._seed_production = parser_types.Production(START_PRIME, (self.start,))
- self.productions = productions + [self._seed_production]
- self._single_level_closure_of_item_cache = {}
- self._closure_of_item_cache = {}
- self._compute_symbols()
- self._compute_seed_firsts()
- self._set_productions_by_lhs()
- self._populate_item_cache()
+ def __init__(self, start_symbol, productions):
+ """Constructs a Grammar object.
- def _set_productions_by_lhs(self):
- # Prepopulating _productions_by_lhs speeds up _closure_of_item by about 30%,
- # which is significant on medium-to-large grammars.
- self._productions_by_lhs = {}
- for production in self.productions:
- self._productions_by_lhs.setdefault(production.lhs, list()).append(
- production)
+ Arguments:
+ start_symbol: The start symbol for the grammar.
+ productions: A list of productions (not including the "S' -> start_symbol"
+ production).
+ """
+ object.__init__(self)
+ self.start = start_symbol
+ self._seed_production = parser_types.Production(START_PRIME, (self.start,))
+ self.productions = productions + [self._seed_production]
- def _populate_item_cache(self):
- # There are a relatively small number of possible Items for a grammar, and
- # the algorithm needs to get Items from their constituent components very
- # frequently. As it turns out, pre-caching all possible Items results in a
- # ~35% overall speedup to Grammar.parser().
- self._item_cache = {}
- for symbol in self.terminals:
- for production in self.productions:
- for dot in range(len(production.rhs) + 1):
- self._item_cache[production, dot, symbol] = make_item(
- production, dot, symbol)
+ self._single_level_closure_of_item_cache = {}
+ self._closure_of_item_cache = {}
+ self._compute_symbols()
+ self._compute_seed_firsts()
+ self._set_productions_by_lhs()
+ self._populate_item_cache()
- def _compute_symbols(self):
- """Finds all grammar symbols, and sorts them into terminal and non-terminal.
+ def _set_productions_by_lhs(self):
+ # Prepopulating _productions_by_lhs speeds up _closure_of_item by about 30%,
+ # which is significant on medium-to-large grammars.
+ self._productions_by_lhs = {}
+ for production in self.productions:
+ self._productions_by_lhs.setdefault(production.lhs, list()).append(
+ production
+ )
- Nonterminal symbols are those which appear on the left side of any
- production. Terminal symbols are those which do not.
+ def _populate_item_cache(self):
+ # There are a relatively small number of possible Items for a grammar, and
+ # the algorithm needs to get Items from their constituent components very
+ # frequently. As it turns out, pre-caching all possible Items results in a
+ # ~35% overall speedup to Grammar.parser().
+ self._item_cache = {}
+ for symbol in self.terminals:
+ for production in self.productions:
+ for dot in range(len(production.rhs) + 1):
+ self._item_cache[production, dot, symbol] = make_item(
+ production, dot, symbol
+ )
- _compute_symbols is used during __init__.
- """
- self.symbols = {END_OF_INPUT}
- self.nonterminals = set()
- for production in self.productions:
- self.symbols.add(production.lhs)
- self.nonterminals.add(production.lhs)
- for symbol in production.rhs:
- self.symbols.add(symbol)
- self.terminals = self.symbols - self.nonterminals
+ def _compute_symbols(self):
+ """Finds all grammar symbols, and sorts them into terminal and non-terminal.
- def _compute_seed_firsts(self):
- """Computes FIRST (ALSU p221) for all terminal and nonterminal symbols.
+ Nonterminal symbols are those which appear on the left side of any
+ production. Terminal symbols are those which do not.
- The algorithm for computing FIRST is an iterative one that terminates when
- it reaches a fixed point (that is, when further iterations stop changing
- state). _compute_seed_firsts computes the fixed point for all single-symbol
- strings, by repeatedly calling _first and updating the internal _firsts
- table with the results.
+ _compute_symbols is used during __init__.
+ """
+ self.symbols = {END_OF_INPUT}
+ self.nonterminals = set()
+ for production in self.productions:
+ self.symbols.add(production.lhs)
+ self.nonterminals.add(production.lhs)
+ for symbol in production.rhs:
+ self.symbols.add(symbol)
+ self.terminals = self.symbols - self.nonterminals
- Once _compute_seed_firsts has completed, _first will return correct results
- for both single- and multi-symbol strings.
+ def _compute_seed_firsts(self):
+ """Computes FIRST (ALSU p221) for all terminal and nonterminal symbols.
- _compute_seed_firsts is used during __init__.
- """
- self.firsts = {}
- # FIRST for a terminal symbol is always just that terminal symbol.
- for terminal in self.terminals:
- self.firsts[terminal] = set([terminal])
- for nonterminal in self.nonterminals:
- self.firsts[nonterminal] = set()
- while True:
- # The first iteration picks up all the productions that start with
- # terminal symbols. The second iteration picks up productions that start
- # with nonterminals that the first iteration picked up. The third
- # iteration picks up nonterminals that the first and second picked up, and
- # so on.
- #
- # This is guaranteed to end, in the worst case, when every terminal
- # symbol and epsilon has been added to the _firsts set for every
- # nonterminal symbol. This would be slow, but requires a pathological
- # grammar; useful grammars should complete in only a few iterations.
- firsts_to_add = {}
- for production in self.productions:
- for first in self._first(production.rhs):
- if first not in self.firsts[production.lhs]:
- if production.lhs not in firsts_to_add:
- firsts_to_add[production.lhs] = set()
- firsts_to_add[production.lhs].add(first)
- if not firsts_to_add:
- break
- for symbol in firsts_to_add:
- self.firsts[symbol].update(firsts_to_add[symbol])
+ The algorithm for computing FIRST is an iterative one that terminates when
+ it reaches a fixed point (that is, when further iterations stop changing
+ state). _compute_seed_firsts computes the fixed point for all single-symbol
+ strings, by repeatedly calling _first and updating the internal _firsts
+ table with the results.
- def _first(self, symbols):
- """The FIRST function from ALSU p221.
+ Once _compute_seed_firsts has completed, _first will return correct results
+ for both single- and multi-symbol strings.
- _first takes a string of symbols (both terminals and nonterminals) and
- returns the set of terminal symbols which could be the first terminal symbol
- of a string produced by the given list of symbols.
+ _compute_seed_firsts is used during __init__.
+ """
+ self.firsts = {}
+ # FIRST for a terminal symbol is always just that terminal symbol.
+ for terminal in self.terminals:
+ self.firsts[terminal] = set([terminal])
+ for nonterminal in self.nonterminals:
+ self.firsts[nonterminal] = set()
+ while True:
+ # The first iteration picks up all the productions that start with
+ # terminal symbols. The second iteration picks up productions that start
+ # with nonterminals that the first iteration picked up. The third
+ # iteration picks up nonterminals that the first and second picked up, and
+ # so on.
+ #
+ # This is guaranteed to end, in the worst case, when every terminal
+ # symbol and epsilon has been added to the _firsts set for every
+ # nonterminal symbol. This would be slow, but requires a pathological
+ # grammar; useful grammars should complete in only a few iterations.
+ firsts_to_add = {}
+ for production in self.productions:
+ for first in self._first(production.rhs):
+ if first not in self.firsts[production.lhs]:
+ if production.lhs not in firsts_to_add:
+ firsts_to_add[production.lhs] = set()
+ firsts_to_add[production.lhs].add(first)
+ if not firsts_to_add:
+ break
+ for symbol in firsts_to_add:
+ self.firsts[symbol].update(firsts_to_add[symbol])
- _first will not give fully-correct results until _compute_seed_firsts
- finishes, but is called by _compute_seed_firsts, and must provide partial
- results during that method's execution.
+ def _first(self, symbols):
+ """The FIRST function from ALSU p221.
- Args:
- symbols: A list of symbols.
+ _first takes a string of symbols (both terminals and nonterminals) and
+ returns the set of terminal symbols which could be the first terminal symbol
+ of a string produced by the given list of symbols.
- Returns:
- A set of terminals which could be the first terminal in "symbols."
- """
- result = set()
- all_contain_epsilon = True
- for symbol in symbols:
- for first in self.firsts[symbol]:
- if first:
- result.add(first)
- if None not in self.firsts[symbol]:
- all_contain_epsilon = False
- break
- if all_contain_epsilon:
- # "None" seems like a Pythonic way of representing epsilon (no symbol).
- result.add(None)
- return result
+ _first will not give fully-correct results until _compute_seed_firsts
+ finishes, but is called by _compute_seed_firsts, and must provide partial
+ results during that method's execution.
- def _closure_of_item(self, root_item):
- """Modified implementation of CLOSURE from ALSU p261.
+ Args:
+ symbols: A list of symbols.
- _closure_of_item performs the CLOSURE function with a single seed item, with
- memoization. In the algorithm as presented in ALSU, CLOSURE is called with
- a different set of items every time, which is unhelpful for memoization.
- Instead, we let _parallel_goto merge the sets returned by _closure_of_item,
- which results in a ~40% speedup.
+ Returns:
+ A set of terminals which could be the first terminal in "symbols."
+ """
+ result = set()
+ all_contain_epsilon = True
+ for symbol in symbols:
+ for first in self.firsts[symbol]:
+ if first:
+ result.add(first)
+ if None not in self.firsts[symbol]:
+ all_contain_epsilon = False
+ break
+ if all_contain_epsilon:
+ # "None" seems like a Pythonic way of representing epsilon (no symbol).
+ result.add(None)
+ return result
- CLOSURE, roughly, computes the set of LR(1) Items which might be active when
- a "seed" set of Items is active.
+ def _closure_of_item(self, root_item):
+ """Modified implementation of CLOSURE from ALSU p261.
- Technically, it is the epsilon-closure of the NFA states represented by
- "items," where an epsilon transition (a transition that does not consume any
- symbols) occurs from a->Z.bY,q to b->.X,p when p is in FIRST(Yq). (a and b
- are nonterminals, X, Y, and Z are arbitrary strings of symbols, and p and q
- are terminals.) That is, it is the set of all NFA states which can be
- reached from "items" without consuming any input. This set corresponds to a
- single DFA state.
+ _closure_of_item performs the CLOSURE function with a single seed item, with
+ memoization. In the algorithm as presented in ALSU, CLOSURE is called with
+ a different set of items every time, which is unhelpful for memoization.
+ Instead, we let _parallel_goto merge the sets returned by _closure_of_item,
+ which results in a ~40% speedup.
- Args:
- root_item: The initial LR(1) Item.
+ CLOSURE, roughly, computes the set of LR(1) Items which might be active when
+ a "seed" set of Items is active.
- Returns:
- A set of LR(1) items which may be active at the time when the provided
- item is active.
- """
- if root_item in self._closure_of_item_cache:
- return self._closure_of_item_cache[root_item]
- item_set = set([root_item])
- item_list = [root_item]
- i = 0
- # Each newly-added Item may trigger the addition of further Items, so
- # iterate until no new Items are added. In the worst case, a new Item will
- # be added for each production.
- #
- # This algorithm is really looking for "next" nonterminals in the existing
- # items, and adding new items corresponding to their productions.
- while i < len(item_list):
- item = item_list[i]
- i += 1
- if not item.next_symbol:
- continue
- # If _closure_of_item_cache contains the full closure of item, then we can
- # add its full closure to the result set, and skip checking any of its
- # items: any item that would be added by any item in the cached result
- # will already be in the _closure_of_item_cache entry.
- if item in self._closure_of_item_cache:
- item_set |= self._closure_of_item_cache[item]
- continue
- # Even if we don't have the full closure of item, we may have the
- # immediate closure of item. It turns out that memoizing just this step
- # speeds up this function by about 50%, even after the
- # _closure_of_item_cache check.
- if item not in self._single_level_closure_of_item_cache:
- new_items = set()
- for production in self._productions_by_lhs.get(item.next_symbol, []):
- for terminal in self._first(item.production.rhs[item.dot + 1:] +
- (item.terminal,)):
- new_items.add(self._item_cache[production, 0, terminal])
- self._single_level_closure_of_item_cache[item] = new_items
- for new_item in self._single_level_closure_of_item_cache[item]:
- if new_item not in item_set:
- item_set.add(new_item)
- item_list.append(new_item)
- self._closure_of_item_cache[root_item] = item_set
- # Typically, _closure_of_item() will be called on items whose closures
- # bring in the greatest number of additional items, then on items which
- # close over fewer and fewer other items. Since items are not added to
- # _closure_of_item_cache unless _closure_of_item() is called directly on
- # them, this means that it is unlikely that items brought in will (without
- # intervention) have entries in _closure_of_item_cache, which slows down the
- # computation of the larger closures.
- #
- # Although it is not guaranteed, items added to item_list last will tend to
- # close over fewer items, and therefore be easier to compute. By forcibly
- # re-calculating closures from last to first, and adding the results to
- # _closure_of_item_cache at each step, we get a modest performance
- # improvement: roughly 50% less time spent in _closure_of_item, which
- # translates to about 5% less time in parser().
- for item in item_list[::-1]:
- self._closure_of_item(item)
- return item_set
+ Technically, it is the epsilon-closure of the NFA states represented by
+ "items," where an epsilon transition (a transition that does not consume any
+ symbols) occurs from a->Z.bY,q to b->.X,p when p is in FIRST(Yq). (a and b
+ are nonterminals, X, Y, and Z are arbitrary strings of symbols, and p and q
+ are terminals.) That is, it is the set of all NFA states which can be
+ reached from "items" without consuming any input. This set corresponds to a
+ single DFA state.
- def _parallel_goto(self, items):
- """The GOTO function from ALSU p261, executed on all symbols.
+ Args:
+ root_item: The initial LR(1) Item.
- _parallel_goto takes a set of Items, and returns a dict from every symbol in
- self.symbols to the set of Items that would be active after a shift
- operation (if symbol is a terminal) or after a reduction operation (if
- symbol is a nonterminal).
+ Returns:
+ A set of LR(1) items which may be active at the time when the provided
+ item is active.
+ """
+ if root_item in self._closure_of_item_cache:
+ return self._closure_of_item_cache[root_item]
+ item_set = set([root_item])
+ item_list = [root_item]
+ i = 0
+ # Each newly-added Item may trigger the addition of further Items, so
+ # iterate until no new Items are added. In the worst case, a new Item will
+ # be added for each production.
+ #
+ # This algorithm is really looking for "next" nonterminals in the existing
+ # items, and adding new items corresponding to their productions.
+ while i < len(item_list):
+ item = item_list[i]
+ i += 1
+ if not item.next_symbol:
+ continue
+ # If _closure_of_item_cache contains the full closure of item, then we can
+ # add its full closure to the result set, and skip checking any of its
+ # items: any item that would be added by any item in the cached result
+ # will already be in the _closure_of_item_cache entry.
+ if item in self._closure_of_item_cache:
+ item_set |= self._closure_of_item_cache[item]
+ continue
+ # Even if we don't have the full closure of item, we may have the
+ # immediate closure of item. It turns out that memoizing just this step
+ # speeds up this function by about 50%, even after the
+ # _closure_of_item_cache check.
+ if item not in self._single_level_closure_of_item_cache:
+ new_items = set()
+ for production in self._productions_by_lhs.get(item.next_symbol, []):
+ for terminal in self._first(
+ item.production.rhs[item.dot + 1 :] + (item.terminal,)
+ ):
+ new_items.add(self._item_cache[production, 0, terminal])
+ self._single_level_closure_of_item_cache[item] = new_items
+ for new_item in self._single_level_closure_of_item_cache[item]:
+ if new_item not in item_set:
+ item_set.add(new_item)
+ item_list.append(new_item)
+ self._closure_of_item_cache[root_item] = item_set
+ # Typically, _closure_of_item() will be called on items whose closures
+ # bring in the greatest number of additional items, then on items which
+ # close over fewer and fewer other items. Since items are not added to
+ # _closure_of_item_cache unless _closure_of_item() is called directly on
+ # them, this means that it is unlikely that items brought in will (without
+ # intervention) have entries in _closure_of_item_cache, which slows down the
+ # computation of the larger closures.
+ #
+ # Although it is not guaranteed, items added to item_list last will tend to
+ # close over fewer items, and therefore be easier to compute. By forcibly
+ # re-calculating closures from last to first, and adding the results to
+ # _closure_of_item_cache at each step, we get a modest performance
+ # improvement: roughly 50% less time spent in _closure_of_item, which
+ # translates to about 5% less time in parser().
+ for item in item_list[::-1]:
+ self._closure_of_item(item)
+ return item_set
- _parallel_goto is used in lieu of the single-symbol GOTO from ALSU because
- it eliminates the outer loop over self.terminals, and thereby reduces the
- number of next_symbol calls by a factor of len(self.terminals).
+ def _parallel_goto(self, items):
+ """The GOTO function from ALSU p261, executed on all symbols.
- Args:
- items: The set of items representing the initial DFA state.
+ _parallel_goto takes a set of Items, and returns a dict from every symbol in
+ self.symbols to the set of Items that would be active after a shift
+ operation (if symbol is a terminal) or after a reduction operation (if
+ symbol is a nonterminal).
- Returns:
- A dict from symbols to sets of items representing the new DFA states.
- """
- results = collections.defaultdict(set)
- for item in items:
- next_symbol = item.next_symbol
- if next_symbol is None:
- continue
- item = self._item_cache[item.production, item.dot + 1, item.terminal]
- # Inlining the cache check results in a ~25% speedup in this function, and
- # about 10% overall speedup to parser().
- if item in self._closure_of_item_cache:
- closure = self._closure_of_item_cache[item]
- else:
- closure = self._closure_of_item(item)
- # _closure will add newly-started Items (Items with dot=0) to the result
- # set. After this operation, the result set will correspond to the new
- # state.
- results[next_symbol].update(closure)
- return results
+ _parallel_goto is used in lieu of the single-symbol GOTO from ALSU because
+ it eliminates the outer loop over self.terminals, and thereby reduces the
+ number of next_symbol calls by a factor of len(self.terminals).
- def _items(self):
- """The items function from ALSU p261.
+ Args:
+ items: The set of items representing the initial DFA state.
- _items computes the set of sets of LR(1) items for a shift-reduce parser
- that matches the grammar. Each set of LR(1) items corresponds to a single
- DFA state.
+ Returns:
+ A dict from symbols to sets of items representing the new DFA states.
+ """
+ results = collections.defaultdict(set)
+ for item in items:
+ next_symbol = item.next_symbol
+ if next_symbol is None:
+ continue
+ item = self._item_cache[item.production, item.dot + 1, item.terminal]
+ # Inlining the cache check results in a ~25% speedup in this function, and
+ # about 10% overall speedup to parser().
+ if item in self._closure_of_item_cache:
+ closure = self._closure_of_item_cache[item]
+ else:
+ closure = self._closure_of_item(item)
+ # _closure will add newly-started Items (Items with dot=0) to the result
+ # set. After this operation, the result set will correspond to the new
+ # state.
+ results[next_symbol].update(closure)
+ return results
- Returns:
- A tuple.
+ def _items(self):
+ """The items function from ALSU p261.
- The first element of the tuple is a list of sets of LR(1) items (each set
- corresponding to a DFA state).
+ _items computes the set of sets of LR(1) items for a shift-reduce parser
+ that matches the grammar. Each set of LR(1) items corresponds to a single
+ DFA state.
- The second element of the tuple is a dictionary from (int, symbol) pairs
- to ints, where all the ints are indexes into the list of sets of LR(1)
- items. This dictionary is based on the results of the _Goto function,
- where item_sets[dict[i, sym]] == self._Goto(item_sets[i], sym).
- """
- # The list of states is seeded with the marker S' production.
- item_list = [
- frozenset(self._closure_of_item(
- self._item_cache[self._seed_production, 0, END_OF_INPUT]))
- ]
- items = {item_list[0]: 0}
- goto_table = {}
- i = 0
- # For each state, figure out what the new state when each symbol is added to
- # the top of the parsing stack (see the comments in parser._parse). See
- # _Goto for an explanation of how that is actually computed.
- while i < len(item_list):
- item_set = item_list[i]
- gotos = self._parallel_goto(item_set)
- for symbol, goto in gotos.items():
- goto = frozenset(goto)
- if goto not in items:
- items[goto] = len(item_list)
- item_list.append(goto)
- goto_table[i, symbol] = items[goto]
- i += 1
- return item_list, goto_table
+ Returns:
+ A tuple.
- def parser(self):
- """parser returns an LR(1) parser for the Grammar.
+ The first element of the tuple is a list of sets of LR(1) items (each set
+ corresponding to a DFA state).
- This implements the Canonical LR(1) ("LR(1)") parser algorithm ("Algorithm
- 4.56", ALSU p265), rather than the more common Lookahead LR(1) ("LALR(1)")
- algorithm. LALR(1) produces smaller tables, but is more complex and does
- not cover all LR(1) grammars. When the LR(1) and LALR(1) algorithms were
- invented, table sizes were an important consideration; now, the difference
- between a few hundred and a few thousand entries is unlikely to matter.
+ The second element of the tuple is a dictionary from (int, symbol) pairs
+ to ints, where all the ints are indexes into the list of sets of LR(1)
+ items. This dictionary is based on the results of the _Goto function,
+ where item_sets[dict[i, sym]] == self._Goto(item_sets[i], sym).
+ """
+ # The list of states is seeded with the marker S' production.
+ item_list = [
+ frozenset(
+ self._closure_of_item(
+ self._item_cache[self._seed_production, 0, END_OF_INPUT]
+ )
+ )
+ ]
+ items = {item_list[0]: 0}
+ goto_table = {}
+ i = 0
+ # For each state, figure out what the new state when each symbol is added to
+ # the top of the parsing stack (see the comments in parser._parse). See
+ # _Goto for an explanation of how that is actually computed.
+ while i < len(item_list):
+ item_set = item_list[i]
+ gotos = self._parallel_goto(item_set)
+ for symbol, goto in gotos.items():
+ goto = frozenset(goto)
+ if goto not in items:
+ items[goto] = len(item_list)
+ item_list.append(goto)
+ goto_table[i, symbol] = items[goto]
+ i += 1
+ return item_list, goto_table
- At this time, Grammar does not handle ambiguous grammars, which are commonly
- used to handle precedence, associativity, and the "dangling else" problem.
- Formally, these can always be handled by an unambiguous grammar, though
- doing so can be cumbersome, particularly for expression languages with many
- levels of precedence. ALSU section 4.8 (pp278-287) contains some techniques
- for handling these kinds of ambiguity.
+ def parser(self):
+ """parser returns an LR(1) parser for the Grammar.
- Returns:
- A Parser.
- """
- item_sets, goto = self._items()
- action = {}
- conflicts = set()
- end_item = self._item_cache[self._seed_production, 1, END_OF_INPUT]
- for i in range(len(item_sets)):
- for item in item_sets[i]:
- new_action = None
- if (item.next_symbol is None and
- item.production != self._seed_production):
- terminal = item.terminal
- new_action = Reduce(item.production)
- elif item.next_symbol in self.terminals:
- terminal = item.next_symbol
- assert goto[i, terminal] is not None
- new_action = Shift(goto[i, terminal], item_sets[goto[i, terminal]])
- if new_action:
- if (i, terminal) in action and action[i, terminal] != new_action:
- conflicts.add(
- Conflict(i, terminal,
- frozenset([action[i, terminal], new_action])))
- action[i, terminal] = new_action
- if item == end_item:
- new_action = Accept()
- assert (i, END_OF_INPUT
- ) not in action or action[i, END_OF_INPUT] == new_action
- action[i, END_OF_INPUT] = new_action
- trimmed_goto = {}
- for k in goto:
- if k[1] in self.nonterminals:
- trimmed_goto[k] = goto[k]
- expected = {}
- for state, terminal in action:
- if state not in expected:
- expected[state] = set()
- expected[state].add(terminal)
- return Parser(item_sets, trimmed_goto, action, expected, conflicts,
- self.terminals, self.nonterminals, self.productions)
+ This implements the Canonical LR(1) ("LR(1)") parser algorithm ("Algorithm
+ 4.56", ALSU p265), rather than the more common Lookahead LR(1) ("LALR(1)")
+ algorithm. LALR(1) produces smaller tables, but is more complex and does
+ not cover all LR(1) grammars. When the LR(1) and LALR(1) algorithms were
+ invented, table sizes were an important consideration; now, the difference
+ between a few hundred and a few thousand entries is unlikely to matter.
+
+ At this time, Grammar does not handle ambiguous grammars, which are commonly
+ used to handle precedence, associativity, and the "dangling else" problem.
+ Formally, these can always be handled by an unambiguous grammar, though
+ doing so can be cumbersome, particularly for expression languages with many
+ levels of precedence. ALSU section 4.8 (pp278-287) contains some techniques
+ for handling these kinds of ambiguity.
+
+ Returns:
+ A Parser.
+ """
+ item_sets, goto = self._items()
+ action = {}
+ conflicts = set()
+ end_item = self._item_cache[self._seed_production, 1, END_OF_INPUT]
+ for i in range(len(item_sets)):
+ for item in item_sets[i]:
+ new_action = None
+ if (
+ item.next_symbol is None
+ and item.production != self._seed_production
+ ):
+ terminal = item.terminal
+ new_action = Reduce(item.production)
+ elif item.next_symbol in self.terminals:
+ terminal = item.next_symbol
+ assert goto[i, terminal] is not None
+ new_action = Shift(goto[i, terminal], item_sets[goto[i, terminal]])
+ if new_action:
+ if (i, terminal) in action and action[i, terminal] != new_action:
+ conflicts.add(
+ Conflict(
+ i,
+ terminal,
+ frozenset([action[i, terminal], new_action]),
+ )
+ )
+ action[i, terminal] = new_action
+ if item == end_item:
+ new_action = Accept()
+ assert (i, END_OF_INPUT) not in action or action[
+ i, END_OF_INPUT
+ ] == new_action
+ action[i, END_OF_INPUT] = new_action
+ trimmed_goto = {}
+ for k in goto:
+ if k[1] in self.nonterminals:
+ trimmed_goto[k] = goto[k]
+ expected = {}
+ for state, terminal in action:
+ if state not in expected:
+ expected[state] = set()
+ expected[state].add(terminal)
+ return Parser(
+ item_sets,
+ trimmed_goto,
+ action,
+ expected,
+ conflicts,
+ self.terminals,
+ self.nonterminals,
+ self.productions,
+ )
-ParseError = collections.namedtuple("ParseError", ["code", "index", "token",
- "state", "expected_tokens"])
+ParseError = collections.namedtuple(
+ "ParseError", ["code", "index", "token", "state", "expected_tokens"]
+)
ParseResult = collections.namedtuple("ParseResult", ["parse_tree", "error"])
class Parser(object):
- """Parser is a shift-reduce LR(1) parser.
+ """Parser is a shift-reduce LR(1) parser.
- Generally, clients will want to get a Parser from a Grammar, rather than
- directly instantiating one.
+ Generally, clients will want to get a Parser from a Grammar, rather than
+ directly instantiating one.
- Parser exposes the raw tables needed to feed into a Shift-Reduce parser,
- but can also be used directly for parsing.
+ Parser exposes the raw tables needed to feed into a Shift-Reduce parser,
+ but can also be used directly for parsing.
- Attributes:
- item_sets: A list of item sets which correspond to the state numbers in
- the action and goto tables. This is not necessary for parsing, but is
- useful for debugging parsers.
- goto: The GOTO table for this parser.
- action: The ACTION table for this parser.
- expected: A table of terminal symbols that are expected (that is, that
- have a non-Error action) for each state. This can be used to provide
- more helpful error messages for parse errors.
- conflicts: A set of unresolved conflicts found during table generation.
- terminals: A set of terminal symbols in the grammar.
- nonterminals: A set of nonterminal symbols in the grammar.
- productions: A list of productions in the grammar.
- default_errors: A dict of states to default error codes to use when
- encountering an error in that state, when a more-specific Error for the
- state/terminal pair has not been set.
- """
-
- def __init__(self, item_sets, goto, action, expected, conflicts, terminals,
- nonterminals, productions):
- super(Parser, self).__init__()
- self.item_sets = item_sets
- self.goto = goto
- self.action = action
- self.expected = expected
- self.conflicts = conflicts
- self.terminals = terminals
- self.nonterminals = nonterminals
- self.productions = productions
- self.default_errors = {}
-
- def _parse(self, tokens):
- """_parse implements Shift-Reduce parsing algorithm.
-
- _parse implements the standard shift-reduce algorithm outlined on ASLU
- pp236-237.
-
- Arguments:
- tokens: the list of token objects to parse.
-
- Returns:
- A ParseResult.
+ Attributes:
+ item_sets: A list of item sets which correspond to the state numbers in
+ the action and goto tables. This is not necessary for parsing, but is
+ useful for debugging parsers.
+ goto: The GOTO table for this parser.
+ action: The ACTION table for this parser.
+ expected: A table of terminal symbols that are expected (that is, that
+ have a non-Error action) for each state. This can be used to provide
+ more helpful error messages for parse errors.
+ conflicts: A set of unresolved conflicts found during table generation.
+ terminals: A set of terminal symbols in the grammar.
+ nonterminals: A set of nonterminal symbols in the grammar.
+ productions: A list of productions in the grammar.
+ default_errors: A dict of states to default error codes to use when
+ encountering an error in that state, when a more-specific Error for the
+ state/terminal pair has not been set.
"""
- # The END_OF_INPUT token is explicitly added to avoid explicit "cursor <
- # len(tokens)" checks.
- tokens = list(tokens) + [Symbol(END_OF_INPUT)]
- # Each element of stack is a parse state and a (possibly partial) parse
- # tree. The state at the top of the stack encodes which productions are
- # "active" (that is, which ones the parser has seen partial input which
- # matches some prefix of the production, in a place where that production
- # might be valid), and, for each active production, how much of the
- # production has been completed.
- stack = [(0, None)]
+ def __init__(
+ self,
+ item_sets,
+ goto,
+ action,
+ expected,
+ conflicts,
+ terminals,
+ nonterminals,
+ productions,
+ ):
+ super(Parser, self).__init__()
+ self.item_sets = item_sets
+ self.goto = goto
+ self.action = action
+ self.expected = expected
+ self.conflicts = conflicts
+ self.terminals = terminals
+ self.nonterminals = nonterminals
+ self.productions = productions
+ self.default_errors = {}
- def state():
- return stack[-1][0]
+ def _parse(self, tokens):
+ """_parse implements Shift-Reduce parsing algorithm.
- cursor = 0
+ _parse implements the standard shift-reduce algorithm outlined on ASLU
+ pp236-237.
- # On each iteration, look at the next symbol and the current state, and
- # perform the corresponding action.
- while True:
- if (state(), tokens[cursor].symbol) not in self.action:
- # Most state/symbol entries would be Errors, so rather than exhaustively
- # adding error entries, we just check here.
- if state() in self.default_errors:
- next_action = Error(self.default_errors[state()])
+ Arguments:
+ tokens: the list of token objects to parse.
+
+ Returns:
+ A ParseResult.
+ """
+ # The END_OF_INPUT token is explicitly added to avoid explicit "cursor <
+ # len(tokens)" checks.
+ tokens = list(tokens) + [Symbol(END_OF_INPUT)]
+
+ # Each element of stack is a parse state and a (possibly partial) parse
+ # tree. The state at the top of the stack encodes which productions are
+ # "active" (that is, which ones the parser has seen partial input which
+ # matches some prefix of the production, in a place where that production
+ # might be valid), and, for each active production, how much of the
+ # production has been completed.
+ stack = [(0, None)]
+
+ def state():
+ return stack[-1][0]
+
+ cursor = 0
+
+ # On each iteration, look at the next symbol and the current state, and
+ # perform the corresponding action.
+ while True:
+ if (state(), tokens[cursor].symbol) not in self.action:
+ # Most state/symbol entries would be Errors, so rather than exhaustively
+ # adding error entries, we just check here.
+ if state() in self.default_errors:
+ next_action = Error(self.default_errors[state()])
+ else:
+ next_action = Error(None)
+ else:
+ next_action = self.action[state(), tokens[cursor].symbol]
+
+ if isinstance(next_action, Shift):
+ # Shift means that there are no "complete" productions on the stack,
+ # and so the current token should be shifted onto the stack, with a new
+ # state indicating the new set of "active" productions.
+ stack.append((next_action.state, tokens[cursor]))
+ cursor += 1
+ elif isinstance(next_action, Accept):
+ # Accept means that parsing is over, successfully.
+ assert len(stack) == 2, "Accepted incompletely-reduced input."
+ assert tokens[cursor].symbol == END_OF_INPUT, (
+ "Accepted parse before " "end of input."
+ )
+ return ParseResult(stack[-1][1], None)
+ elif isinstance(next_action, Reduce):
+ # Reduce means that there is a complete production on the stack, and
+ # that the next symbol implies that the completed production is the
+ # correct production.
+ #
+ # Per ALSU, we would simply pop an element off the state stack for each
+ # symbol on the rhs of the production, and then push a new state by
+ # looking up the (post-pop) current state and the lhs of the production
+ # in GOTO. The GOTO table, in some sense, is equivalent to shift
+ # actions for nonterminal symbols.
+ #
+ # Here, we attach a new partial parse tree, with the production lhs as
+ # the "name" of the tree, and the popped trees as the "children" of the
+ # new tree.
+ children = [
+ item[1] for item in stack[len(stack) - len(next_action.rule.rhs) :]
+ ]
+ # Attach source_location, if known. The source location will not be
+ # known if the reduction consumes no symbols (empty rhs) or if the
+ # client did not specify source_locations for tokens.
+ #
+ # It is necessary to loop in order to handle cases like:
+ #
+ # C -> c D
+ # D ->
+ #
+ # The D child of the C reduction will not have a source location
+ # (because it is not produced from any source), so it is necessary to
+ # scan backwards through C's children to find the end position. The
+ # opposite is required in the case where initial children have no
+ # source.
+ #
+ # These loops implicitly handle the case where the reduction has no
+ # children, setting the source_location to None in that case.
+ start_position = None
+ end_position = None
+ for child in children:
+ if (
+ hasattr(child, "source_location")
+ and child.source_location is not None
+ ):
+ start_position = child.source_location.start
+ break
+ for child in reversed(children):
+ if (
+ hasattr(child, "source_location")
+ and child.source_location is not None
+ ):
+ end_position = child.source_location.end
+ break
+ if start_position is None:
+ source_location = None
+ else:
+ source_location = parser_types.make_location(
+ start_position, end_position
+ )
+ reduction = Reduction(
+ next_action.rule.lhs, children, next_action.rule, source_location
+ )
+ del stack[len(stack) - len(next_action.rule.rhs) :]
+ stack.append((self.goto[state(), next_action.rule.lhs], reduction))
+ elif isinstance(next_action, Error):
+ # Error means that the parse is impossible. For typical grammars and
+ # texts, this usually happens within a few tokens after the mistake in
+ # the input stream, which is convenient (though imperfect) for error
+ # reporting.
+ return ParseResult(
+ None,
+ ParseError(
+ next_action.code,
+ cursor,
+ tokens[cursor],
+ state(),
+ self.expected[state()],
+ ),
+ )
+ else:
+ assert False, "Shouldn't be here."
+
+ def mark_error(self, tokens, error_token, error_code):
+ """Marks an error state with the given error code.
+
+ mark_error implements the equivalent of the "Merr" system presented in
+ "Generating LR Syntax error Messages from Examples" (Jeffery, 2003).
+ This system has limitations, but has the primary advantage that error
+ messages can be specified by giving an example of the error and the
+ message itself.
+
+ Arguments:
+ tokens: a list of tokens to parse.
+ error_token: the token where the parse should fail, or None if the parse
+ should fail at the implicit end-of-input token.
+
+ If the error_token is the special ANY_TOKEN, then the error will be
+ recorded as the default error for the error state.
+ error_code: a value to record for the error state reached by parsing
+ tokens.
+
+ Returns:
+ None if error_code was successfully recorded, or an error message if there
+ was a problem.
+ """
+ result = self._parse(tokens)
+
+ # There is no error state to mark on a successful parse.
+ if not result.error:
+ return "Input successfully parsed."
+
+ # Check if the error occurred at the specified token; if not, then this was
+ # not the expected error.
+ if error_token is None:
+ error_symbol = END_OF_INPUT
+ if result.error.token.symbol != END_OF_INPUT:
+ return "error occurred on {} token, not end of input.".format(
+ result.error.token.symbol
+ )
else:
- next_action = Error(None)
- else:
- next_action = self.action[state(), tokens[cursor].symbol]
+ error_symbol = error_token.symbol
+ if result.error.token != error_token:
+ return "error occurred on {} token, not {} token.".format(
+ result.error.token.symbol, error_token.symbol
+ )
- if isinstance(next_action, Shift):
- # Shift means that there are no "complete" productions on the stack,
- # and so the current token should be shifted onto the stack, with a new
- # state indicating the new set of "active" productions.
- stack.append((next_action.state, tokens[cursor]))
- cursor += 1
- elif isinstance(next_action, Accept):
- # Accept means that parsing is over, successfully.
- assert len(stack) == 2, "Accepted incompletely-reduced input."
- assert tokens[cursor].symbol == END_OF_INPUT, ("Accepted parse before "
- "end of input.")
- return ParseResult(stack[-1][1], None)
- elif isinstance(next_action, Reduce):
- # Reduce means that there is a complete production on the stack, and
- # that the next symbol implies that the completed production is the
- # correct production.
- #
- # Per ALSU, we would simply pop an element off the state stack for each
- # symbol on the rhs of the production, and then push a new state by
- # looking up the (post-pop) current state and the lhs of the production
- # in GOTO. The GOTO table, in some sense, is equivalent to shift
- # actions for nonterminal symbols.
- #
- # Here, we attach a new partial parse tree, with the production lhs as
- # the "name" of the tree, and the popped trees as the "children" of the
- # new tree.
- children = [
- item[1] for item in stack[len(stack) - len(next_action.rule.rhs):]
- ]
- # Attach source_location, if known. The source location will not be
- # known if the reduction consumes no symbols (empty rhs) or if the
- # client did not specify source_locations for tokens.
- #
- # It is necessary to loop in order to handle cases like:
- #
- # C -> c D
- # D ->
- #
- # The D child of the C reduction will not have a source location
- # (because it is not produced from any source), so it is necessary to
- # scan backwards through C's children to find the end position. The
- # opposite is required in the case where initial children have no
- # source.
- #
- # These loops implicitly handle the case where the reduction has no
- # children, setting the source_location to None in that case.
- start_position = None
- end_position = None
- for child in children:
- if hasattr(child,
- "source_location") and child.source_location is not None:
- start_position = child.source_location.start
- break
- for child in reversed(children):
- if hasattr(child,
- "source_location") and child.source_location is not None:
- end_position = child.source_location.end
- break
- if start_position is None:
- source_location = None
+ # If the expected error was found, attempt to mark it. It is acceptable if
+ # the given error_code is already set as the error code for the given parse,
+ # but not if a different code is set.
+ if result.error.token == ANY_TOKEN:
+ # For ANY_TOKEN, mark it as a default error.
+ if result.error.state in self.default_errors:
+ if self.default_errors[result.error.state] == error_code:
+ return None
+ else:
+ return (
+ "Attempted to overwrite existing default error code {!r} "
+ "with new error code {!r} for state {}".format(
+ self.default_errors[result.error.state],
+ error_code,
+ result.error.state,
+ )
+ )
+ else:
+ self.default_errors[result.error.state] = error_code
+ return None
else:
- source_location = parser_types.make_location(start_position,
- end_position)
- reduction = Reduction(next_action.rule.lhs, children, next_action.rule,
- source_location)
- del stack[len(stack) - len(next_action.rule.rhs):]
- stack.append((self.goto[state(), next_action.rule.lhs], reduction))
- elif isinstance(next_action, Error):
- # Error means that the parse is impossible. For typical grammars and
- # texts, this usually happens within a few tokens after the mistake in
- # the input stream, which is convenient (though imperfect) for error
- # reporting.
- return ParseResult(None,
- ParseError(next_action.code, cursor, tokens[cursor],
- state(), self.expected[state()]))
- else:
- assert False, "Shouldn't be here."
+ if (result.error.state, error_symbol) in self.action:
+ existing_error = self.action[result.error.state, error_symbol]
+ assert isinstance(existing_error, Error), "Bug"
+ if existing_error.code == error_code:
+ return None
+ else:
+ return (
+ "Attempted to overwrite existing error code {!r} with new "
+ "error code {!r} for state {}, terminal {}".format(
+ existing_error.code,
+ error_code,
+ result.error.state,
+ error_symbol,
+ )
+ )
+ else:
+ self.action[result.error.state, error_symbol] = Error(error_code)
+ return None
+ assert False, "All other paths should lead to return."
- def mark_error(self, tokens, error_token, error_code):
- """Marks an error state with the given error code.
+ def parse(self, tokens):
+ """Parses a list of tokens.
- mark_error implements the equivalent of the "Merr" system presented in
- "Generating LR Syntax error Messages from Examples" (Jeffery, 2003).
- This system has limitations, but has the primary advantage that error
- messages can be specified by giving an example of the error and the
- message itself.
+ Arguments:
+ tokens: a list of tokens to parse.
- Arguments:
- tokens: a list of tokens to parse.
- error_token: the token where the parse should fail, or None if the parse
- should fail at the implicit end-of-input token.
-
- If the error_token is the special ANY_TOKEN, then the error will be
- recorded as the default error for the error state.
- error_code: a value to record for the error state reached by parsing
- tokens.
-
- Returns:
- None if error_code was successfully recorded, or an error message if there
- was a problem.
- """
- result = self._parse(tokens)
-
- # There is no error state to mark on a successful parse.
- if not result.error:
- return "Input successfully parsed."
-
- # Check if the error occurred at the specified token; if not, then this was
- # not the expected error.
- if error_token is None:
- error_symbol = END_OF_INPUT
- if result.error.token.symbol != END_OF_INPUT:
- return "error occurred on {} token, not end of input.".format(
- result.error.token.symbol)
- else:
- error_symbol = error_token.symbol
- if result.error.token != error_token:
- return "error occurred on {} token, not {} token.".format(
- result.error.token.symbol, error_token.symbol)
-
- # If the expected error was found, attempt to mark it. It is acceptable if
- # the given error_code is already set as the error code for the given parse,
- # but not if a different code is set.
- if result.error.token == ANY_TOKEN:
- # For ANY_TOKEN, mark it as a default error.
- if result.error.state in self.default_errors:
- if self.default_errors[result.error.state] == error_code:
- return None
- else:
- return ("Attempted to overwrite existing default error code {!r} "
- "with new error code {!r} for state {}".format(
- self.default_errors[result.error.state], error_code,
- result.error.state))
- else:
- self.default_errors[result.error.state] = error_code
- return None
- else:
- if (result.error.state, error_symbol) in self.action:
- existing_error = self.action[result.error.state, error_symbol]
- assert isinstance(existing_error, Error), "Bug"
- if existing_error.code == error_code:
- return None
- else:
- return ("Attempted to overwrite existing error code {!r} with new "
- "error code {!r} for state {}, terminal {}".format(
- existing_error.code, error_code, result.error.state,
- error_symbol))
- else:
- self.action[result.error.state, error_symbol] = Error(error_code)
- return None
- assert False, "All other paths should lead to return."
-
- def parse(self, tokens):
- """Parses a list of tokens.
-
- Arguments:
- tokens: a list of tokens to parse.
-
- Returns:
- A ParseResult.
- """
- result = self._parse(tokens)
- return result
+ Returns:
+ A ParseResult.
+ """
+ result = self._parse(tokens)
+ return result
diff --git a/compiler/front_end/lr1_test.py b/compiler/front_end/lr1_test.py
index 44ffa75..ae03e2d 100644
--- a/compiler/front_end/lr1_test.py
+++ b/compiler/front_end/lr1_test.py
@@ -22,58 +22,77 @@
def _make_items(text):
- """Makes a list of lr1.Items from the lines in text."""
- return frozenset([lr1.Item.parse(line.strip()) for line in text.splitlines()])
+ """Makes a list of lr1.Items from the lines in text."""
+ return frozenset([lr1.Item.parse(line.strip()) for line in text.splitlines()])
Token = collections.namedtuple("Token", ["symbol", "source_location"])
def _tokenize(text):
- """"Tokenizes" text by making each character into a token."""
- result = []
- for i in range(len(text)):
- result.append(Token(text[i], parser_types.make_location(
- (1, i + 1), (1, i + 2))))
- return result
+ """ "Tokenizes" text by making each character into a token."""
+ result = []
+ for i in range(len(text)):
+ result.append(
+ Token(text[i], parser_types.make_location((1, i + 1), (1, i + 2)))
+ )
+ return result
def _parse_productions(text):
- """Parses text into a grammar by calling Production.parse on each line."""
- return [parser_types.Production.parse(line) for line in text.splitlines()]
+ """Parses text into a grammar by calling Production.parse on each line."""
+ return [parser_types.Production.parse(line) for line in text.splitlines()]
+
# Example grammar 4.54 from Aho, Sethi, Lam, Ullman (ASLU) p263.
-_alsu_grammar = lr1.Grammar("S", _parse_productions("""S -> C C
+_alsu_grammar = lr1.Grammar(
+ "S",
+ _parse_productions(
+ """S -> C C
C -> c C
- C -> d"""))
+ C -> d"""
+ ),
+)
# Item sets corresponding to the above grammar, ASLU pp263-264.
_alsu_items = [
- _make_items("""S' -> . S, $
+ _make_items(
+ """S' -> . S, $
S -> . C C, $
C -> . c C, c
C -> . c C, d
C -> . d, c
- C -> . d, d"""),
+ C -> . d, d"""
+ ),
_make_items("""S' -> S ., $"""),
- _make_items("""S -> C . C, $
+ _make_items(
+ """S -> C . C, $
C -> . c C, $
- C -> . d, $"""),
- _make_items("""C -> c . C, c
+ C -> . d, $"""
+ ),
+ _make_items(
+ """C -> c . C, c
C -> c . C, d
C -> . c C, c
C -> . c C, d
C -> . d, c
- C -> . d, d"""),
- _make_items("""C -> d ., c
- C -> d ., d"""),
+ C -> . d, d"""
+ ),
+ _make_items(
+ """C -> d ., c
+ C -> d ., d"""
+ ),
_make_items("""S -> C C ., $"""),
- _make_items("""C -> c . C, $
+ _make_items(
+ """C -> c . C, $
C -> . c C, $
- C -> . d, $"""),
+ C -> . d, $"""
+ ),
_make_items("""C -> d ., $"""),
- _make_items("""C -> c C ., c
- C -> c C ., d"""),
+ _make_items(
+ """C -> c C ., c
+ C -> c C ., d"""
+ ),
_make_items("""C -> c C ., $"""),
]
@@ -98,220 +117,315 @@
}
# GOTO table corresponding to the above grammar, ASLU p266.
-_alsu_goto = {(0, "S"): 1, (0, "C"): 2, (2, "C"): 5, (3, "C"): 8, (6, "C"): 9,}
+_alsu_goto = {
+ (0, "S"): 1,
+ (0, "C"): 2,
+ (2, "C"): 5,
+ (3, "C"): 8,
+ (6, "C"): 9,
+}
def _normalize_table(items, table):
- """Returns a canonical-form version of items and table, for comparisons."""
- item_to_original_index = {}
- for i in range(len(items)):
- item_to_original_index[items[i]] = i
- sorted_items = items[0:1] + sorted(items[1:], key=sorted)
- original_index_to_index = {}
- for i in range(len(sorted_items)):
- original_index_to_index[item_to_original_index[sorted_items[i]]] = i
- updated_table = {}
- for k in table:
- new_k = original_index_to_index[k[0]], k[1]
- new_value = table[k]
- if isinstance(new_value, int):
- new_value = original_index_to_index[new_value]
- elif isinstance(new_value, lr1.Shift):
- new_value = lr1.Shift(original_index_to_index[new_value.state],
- new_value.items)
- updated_table[new_k] = new_value
- return sorted_items, updated_table
+ """Returns a canonical-form version of items and table, for comparisons."""
+ item_to_original_index = {}
+ for i in range(len(items)):
+ item_to_original_index[items[i]] = i
+ sorted_items = items[0:1] + sorted(items[1:], key=sorted)
+ original_index_to_index = {}
+ for i in range(len(sorted_items)):
+ original_index_to_index[item_to_original_index[sorted_items[i]]] = i
+ updated_table = {}
+ for k in table:
+ new_k = original_index_to_index[k[0]], k[1]
+ new_value = table[k]
+ if isinstance(new_value, int):
+ new_value = original_index_to_index[new_value]
+ elif isinstance(new_value, lr1.Shift):
+ new_value = lr1.Shift(
+ original_index_to_index[new_value.state], new_value.items
+ )
+ updated_table[new_k] = new_value
+ return sorted_items, updated_table
class Lr1Test(unittest.TestCase):
- """Tests for lr1."""
+ """Tests for lr1."""
- def test_parse_lr1item(self):
- self.assertEqual(lr1.Item.parse("S' -> . S, $"),
- lr1.Item(parser_types.Production(lr1.START_PRIME, ("S",)),
- 0, lr1.END_OF_INPUT, "S"))
+ def test_parse_lr1item(self):
+ self.assertEqual(
+ lr1.Item.parse("S' -> . S, $"),
+ lr1.Item(
+ parser_types.Production(lr1.START_PRIME, ("S",)),
+ 0,
+ lr1.END_OF_INPUT,
+ "S",
+ ),
+ )
- def test_symbol_extraction(self):
- self.assertEqual(_alsu_grammar.terminals, set(["c", "d", lr1.END_OF_INPUT]))
- self.assertEqual(_alsu_grammar.nonterminals, set(["S", "C",
- lr1.START_PRIME]))
- self.assertEqual(_alsu_grammar.symbols,
- set(["c", "d", "S", "C", lr1.END_OF_INPUT,
- lr1.START_PRIME]))
+ def test_symbol_extraction(self):
+ self.assertEqual(_alsu_grammar.terminals, set(["c", "d", lr1.END_OF_INPUT]))
+ self.assertEqual(_alsu_grammar.nonterminals, set(["S", "C", lr1.START_PRIME]))
+ self.assertEqual(
+ _alsu_grammar.symbols,
+ set(["c", "d", "S", "C", lr1.END_OF_INPUT, lr1.START_PRIME]),
+ )
- def test_items(self):
- self.assertEqual(set(_alsu_grammar._items()[0]), frozenset(_alsu_items))
+ def test_items(self):
+ self.assertEqual(set(_alsu_grammar._items()[0]), frozenset(_alsu_items))
- def test_terminal_nonterminal_production_tables(self):
- parser = _alsu_grammar.parser()
- self.assertEqual(parser.terminals, _alsu_grammar.terminals)
- self.assertEqual(parser.nonterminals, _alsu_grammar.nonterminals)
- self.assertEqual(parser.productions, _alsu_grammar.productions)
+ def test_terminal_nonterminal_production_tables(self):
+ parser = _alsu_grammar.parser()
+ self.assertEqual(parser.terminals, _alsu_grammar.terminals)
+ self.assertEqual(parser.nonterminals, _alsu_grammar.nonterminals)
+ self.assertEqual(parser.productions, _alsu_grammar.productions)
- def test_action_table(self):
- parser = _alsu_grammar.parser()
- norm_items, norm_action = _normalize_table(parser.item_sets, parser.action)
- test_items, test_action = _normalize_table(_alsu_items, _alsu_action)
- self.assertEqual(norm_items, test_items)
- self.assertEqual(norm_action, test_action)
+ def test_action_table(self):
+ parser = _alsu_grammar.parser()
+ norm_items, norm_action = _normalize_table(parser.item_sets, parser.action)
+ test_items, test_action = _normalize_table(_alsu_items, _alsu_action)
+ self.assertEqual(norm_items, test_items)
+ self.assertEqual(norm_action, test_action)
- def test_goto_table(self):
- parser = _alsu_grammar.parser()
- norm_items, norm_goto = _normalize_table(parser.item_sets, parser.goto)
- test_items, test_goto = _normalize_table(_alsu_items, _alsu_goto)
- self.assertEqual(norm_items, test_items)
- self.assertEqual(norm_goto, test_goto)
+ def test_goto_table(self):
+ parser = _alsu_grammar.parser()
+ norm_items, norm_goto = _normalize_table(parser.item_sets, parser.goto)
+ test_items, test_goto = _normalize_table(_alsu_items, _alsu_goto)
+ self.assertEqual(norm_items, test_items)
+ self.assertEqual(norm_goto, test_goto)
- def test_successful_parse(self):
- parser = _alsu_grammar.parser()
- loc = parser_types.parse_location
- s_to_c_c = parser_types.Production.parse("S -> C C")
- c_to_c_c = parser_types.Production.parse("C -> c C")
- c_to_d = parser_types.Production.parse("C -> d")
- self.assertEqual(
- lr1.Reduction("S", [lr1.Reduction("C", [
- Token("c", loc("1:1-1:2")), lr1.Reduction(
- "C", [Token("c", loc("1:2-1:3")),
- lr1.Reduction("C",
- [Token("c", loc("1:3-1:4")), lr1.Reduction(
- "C", [Token("d", loc("1:4-1:5"))],
- c_to_d, loc("1:4-1:5"))], c_to_c_c,
- loc("1:3-1:5"))], c_to_c_c, loc("1:2-1:5"))
- ], c_to_c_c, loc("1:1-1:5")), lr1.Reduction(
- "C", [Token("c", loc("1:5-1:6")),
- lr1.Reduction("C", [Token("d", loc("1:6-1:7"))], c_to_d,
- loc("1:6-1:7"))], c_to_c_c, loc("1:5-1:7"))],
- s_to_c_c, loc("1:1-1:7")),
- parser.parse(_tokenize("cccdcd")).parse_tree)
- self.assertEqual(
- lr1.Reduction("S", [
- lr1.Reduction("C", [Token("d", loc("1:1-1:2"))], c_to_d, loc(
- "1:1-1:2")), lr1.Reduction("C", [Token("d", loc("1:2-1:3"))],
- c_to_d, loc("1:2-1:3"))
- ], s_to_c_c, loc("1:1-1:3")), parser.parse(_tokenize("dd")).parse_tree)
+ def test_successful_parse(self):
+ parser = _alsu_grammar.parser()
+ loc = parser_types.parse_location
+ s_to_c_c = parser_types.Production.parse("S -> C C")
+ c_to_c_c = parser_types.Production.parse("C -> c C")
+ c_to_d = parser_types.Production.parse("C -> d")
+ self.assertEqual(
+ lr1.Reduction(
+ "S",
+ [
+ lr1.Reduction(
+ "C",
+ [
+ Token("c", loc("1:1-1:2")),
+ lr1.Reduction(
+ "C",
+ [
+ Token("c", loc("1:2-1:3")),
+ lr1.Reduction(
+ "C",
+ [
+ Token("c", loc("1:3-1:4")),
+ lr1.Reduction(
+ "C",
+ [Token("d", loc("1:4-1:5"))],
+ c_to_d,
+ loc("1:4-1:5"),
+ ),
+ ],
+ c_to_c_c,
+ loc("1:3-1:5"),
+ ),
+ ],
+ c_to_c_c,
+ loc("1:2-1:5"),
+ ),
+ ],
+ c_to_c_c,
+ loc("1:1-1:5"),
+ ),
+ lr1.Reduction(
+ "C",
+ [
+ Token("c", loc("1:5-1:6")),
+ lr1.Reduction(
+ "C",
+ [Token("d", loc("1:6-1:7"))],
+ c_to_d,
+ loc("1:6-1:7"),
+ ),
+ ],
+ c_to_c_c,
+ loc("1:5-1:7"),
+ ),
+ ],
+ s_to_c_c,
+ loc("1:1-1:7"),
+ ),
+ parser.parse(_tokenize("cccdcd")).parse_tree,
+ )
+ self.assertEqual(
+ lr1.Reduction(
+ "S",
+ [
+ lr1.Reduction(
+ "C", [Token("d", loc("1:1-1:2"))], c_to_d, loc("1:1-1:2")
+ ),
+ lr1.Reduction(
+ "C", [Token("d", loc("1:2-1:3"))], c_to_d, loc("1:2-1:3")
+ ),
+ ],
+ s_to_c_c,
+ loc("1:1-1:3"),
+ ),
+ parser.parse(_tokenize("dd")).parse_tree,
+ )
- def test_parse_with_no_source_information(self):
- parser = _alsu_grammar.parser()
- s_to_c_c = parser_types.Production.parse("S -> C C")
- c_to_d = parser_types.Production.parse("C -> d")
- self.assertEqual(
- lr1.Reduction("S", [
- lr1.Reduction("C", [Token("d", None)], c_to_d, None),
- lr1.Reduction("C", [Token("d", None)], c_to_d, None)
- ], s_to_c_c, None),
- parser.parse([Token("d", None), Token("d", None)]).parse_tree)
+ def test_parse_with_no_source_information(self):
+ parser = _alsu_grammar.parser()
+ s_to_c_c = parser_types.Production.parse("S -> C C")
+ c_to_d = parser_types.Production.parse("C -> d")
+ self.assertEqual(
+ lr1.Reduction(
+ "S",
+ [
+ lr1.Reduction("C", [Token("d", None)], c_to_d, None),
+ lr1.Reduction("C", [Token("d", None)], c_to_d, None),
+ ],
+ s_to_c_c,
+ None,
+ ),
+ parser.parse([Token("d", None), Token("d", None)]).parse_tree,
+ )
- def test_failed_parses(self):
- parser = _alsu_grammar.parser()
- self.assertEqual(None, parser.parse(_tokenize("d")).parse_tree)
- self.assertEqual(None, parser.parse(_tokenize("cccd")).parse_tree)
- self.assertEqual(None, parser.parse(_tokenize("")).parse_tree)
- self.assertEqual(None, parser.parse(_tokenize("cccdc")).parse_tree)
+ def test_failed_parses(self):
+ parser = _alsu_grammar.parser()
+ self.assertEqual(None, parser.parse(_tokenize("d")).parse_tree)
+ self.assertEqual(None, parser.parse(_tokenize("cccd")).parse_tree)
+ self.assertEqual(None, parser.parse(_tokenize("")).parse_tree)
+ self.assertEqual(None, parser.parse(_tokenize("cccdc")).parse_tree)
- def test_mark_error(self):
- parser = _alsu_grammar.parser()
- self.assertIsNone(parser.mark_error(_tokenize("cccdc"), None,
- "missing last d"))
- self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C"))
- # Marking an already-marked error with the same error code should succeed.
- self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C"))
- # Marking an already-marked error with a different error code should fail.
- self.assertRegexpMatches(
- parser.mark_error(_tokenize("d"), None, "different message"),
- r"^Attempted to overwrite existing error code 'missing last C' with "
- r"new error code 'different message' for state \d+, terminal \$$")
- self.assertEqual(
- "Input successfully parsed.",
- parser.mark_error(_tokenize("dd"), None, "good parse"))
- self.assertEqual(
- parser.mark_error(_tokenize("x"), None, "wrong location"),
- "error occurred on x token, not end of input.")
- self.assertEqual(
- parser.mark_error([], _tokenize("x")[0], "wrong location"),
- "error occurred on $ token, not x token.")
- self.assertIsNone(
- parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error"))
- # Marking an already-marked error with the same error code should succeed.
- self.assertIsNone(
- parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error"))
- # Marking an already-marked error with a different error code should fail.
- self.assertRegexpMatches(
- parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error 2"),
- r"^Attempted to overwrite existing default error code 'default error' "
- r"with new error code 'default error 2' for state \d+$")
+ def test_mark_error(self):
+ parser = _alsu_grammar.parser()
+ self.assertIsNone(parser.mark_error(_tokenize("cccdc"), None, "missing last d"))
+ self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C"))
+ # Marking an already-marked error with the same error code should succeed.
+ self.assertIsNone(parser.mark_error(_tokenize("d"), None, "missing last C"))
+ # Marking an already-marked error with a different error code should fail.
+ self.assertRegexpMatches(
+ parser.mark_error(_tokenize("d"), None, "different message"),
+ r"^Attempted to overwrite existing error code 'missing last C' with "
+ r"new error code 'different message' for state \d+, terminal \$$",
+ )
+ self.assertEqual(
+ "Input successfully parsed.",
+ parser.mark_error(_tokenize("dd"), None, "good parse"),
+ )
+ self.assertEqual(
+ parser.mark_error(_tokenize("x"), None, "wrong location"),
+ "error occurred on x token, not end of input.",
+ )
+ self.assertEqual(
+ parser.mark_error([], _tokenize("x")[0], "wrong location"),
+ "error occurred on $ token, not x token.",
+ )
+ self.assertIsNone(
+ parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error")
+ )
+ # Marking an already-marked error with the same error code should succeed.
+ self.assertIsNone(
+ parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error")
+ )
+ # Marking an already-marked error with a different error code should fail.
+ self.assertRegexpMatches(
+ parser.mark_error([lr1.ANY_TOKEN], lr1.ANY_TOKEN, "default error 2"),
+ r"^Attempted to overwrite existing default error code 'default error' "
+ r"with new error code 'default error 2' for state \d+$",
+ )
- self.assertEqual(
- "missing last d", parser.parse(_tokenize("cccdc")).error.code)
- self.assertEqual("missing last d", parser.parse(_tokenize("dc")).error.code)
- self.assertEqual("missing last C", parser.parse(_tokenize("d")).error.code)
- self.assertEqual("default error", parser.parse(_tokenize("z")).error.code)
- self.assertEqual(
- "missing last C", parser.parse(_tokenize("ccccd")).error.code)
- self.assertEqual(None, parser.parse(_tokenize("ccc")).error.code)
+ self.assertEqual("missing last d", parser.parse(_tokenize("cccdc")).error.code)
+ self.assertEqual("missing last d", parser.parse(_tokenize("dc")).error.code)
+ self.assertEqual("missing last C", parser.parse(_tokenize("d")).error.code)
+ self.assertEqual("default error", parser.parse(_tokenize("z")).error.code)
+ self.assertEqual("missing last C", parser.parse(_tokenize("ccccd")).error.code)
+ self.assertEqual(None, parser.parse(_tokenize("ccc")).error.code)
- def test_grammar_with_empty_rhs(self):
- grammar = lr1.Grammar("S", _parse_productions("""S -> A B
+ def test_grammar_with_empty_rhs(self):
+ grammar = lr1.Grammar(
+ "S",
+ _parse_productions(
+ """S -> A B
A -> a A
A ->
- B -> b"""))
- parser = grammar.parser()
- self.assertFalse(parser.conflicts)
- self.assertTrue(parser.parse(_tokenize("ab")).parse_tree)
- self.assertTrue(parser.parse(_tokenize("b")).parse_tree)
- self.assertTrue(parser.parse(_tokenize("aab")).parse_tree)
+ B -> b"""
+ ),
+ )
+ parser = grammar.parser()
+ self.assertFalse(parser.conflicts)
+ self.assertTrue(parser.parse(_tokenize("ab")).parse_tree)
+ self.assertTrue(parser.parse(_tokenize("b")).parse_tree)
+ self.assertTrue(parser.parse(_tokenize("aab")).parse_tree)
- def test_grammar_with_reduce_reduce_conflicts(self):
- grammar = lr1.Grammar("S", _parse_productions("""S -> A c
+ def test_grammar_with_reduce_reduce_conflicts(self):
+ grammar = lr1.Grammar(
+ "S",
+ _parse_productions(
+ """S -> A c
S -> B c
A -> a
- B -> a"""))
- parser = grammar.parser()
- self.assertEqual(len(parser.conflicts), 1)
- # parser.conflicts is a set
- for conflict in parser.conflicts:
- for action in conflict.actions:
- self.assertTrue(isinstance(action, lr1.Reduce))
+ B -> a"""
+ ),
+ )
+ parser = grammar.parser()
+ self.assertEqual(len(parser.conflicts), 1)
+ # parser.conflicts is a set
+ for conflict in parser.conflicts:
+ for action in conflict.actions:
+ self.assertTrue(isinstance(action, lr1.Reduce))
- def test_grammar_with_shift_reduce_conflicts(self):
- grammar = lr1.Grammar("S", _parse_productions("""S -> A B
+ def test_grammar_with_shift_reduce_conflicts(self):
+ grammar = lr1.Grammar(
+ "S",
+ _parse_productions(
+ """S -> A B
A -> a
A ->
B -> a
- B ->"""))
- parser = grammar.parser()
- self.assertEqual(len(parser.conflicts), 1)
- # parser.conflicts is a set
- for conflict in parser.conflicts:
- reduces = 0
- shifts = 0
- for action in conflict.actions:
- if isinstance(action, lr1.Reduce):
- reduces += 1
- elif isinstance(action, lr1.Shift):
- shifts += 1
- self.assertEqual(1, reduces)
- self.assertEqual(1, shifts)
+ B ->"""
+ ),
+ )
+ parser = grammar.parser()
+ self.assertEqual(len(parser.conflicts), 1)
+ # parser.conflicts is a set
+ for conflict in parser.conflicts:
+ reduces = 0
+ shifts = 0
+ for action in conflict.actions:
+ if isinstance(action, lr1.Reduce):
+ reduces += 1
+ elif isinstance(action, lr1.Shift):
+ shifts += 1
+ self.assertEqual(1, reduces)
+ self.assertEqual(1, shifts)
- def test_item_str(self):
- self.assertEqual(
- "a -> b c ., d",
- str(lr1.make_item(parser_types.Production.parse("a -> b c"), 2, "d")))
- self.assertEqual(
- "a -> b . c, d",
- str(lr1.make_item(parser_types.Production.parse("a -> b c"), 1, "d")))
- self.assertEqual(
- "a -> . b c, d",
- str(lr1.make_item(parser_types.Production.parse("a -> b c"), 0, "d")))
- self.assertEqual(
- "a -> ., d",
- str(lr1.make_item(parser_types.Production.parse("a ->"), 0, "d")))
+ def test_item_str(self):
+ self.assertEqual(
+ "a -> b c ., d",
+ str(lr1.make_item(parser_types.Production.parse("a -> b c"), 2, "d")),
+ )
+ self.assertEqual(
+ "a -> b . c, d",
+ str(lr1.make_item(parser_types.Production.parse("a -> b c"), 1, "d")),
+ )
+ self.assertEqual(
+ "a -> . b c, d",
+ str(lr1.make_item(parser_types.Production.parse("a -> b c"), 0, "d")),
+ )
+ self.assertEqual(
+ "a -> ., d",
+ str(lr1.make_item(parser_types.Production.parse("a ->"), 0, "d")),
+ )
- def test_conflict_str(self):
- self.assertEqual("Conflict for 'A' in state 12: R vs S",
- str(lr1.Conflict(12, "'A'", ["R", "S"])))
- self.assertEqual("Conflict for 'A' in state 12: R vs S vs T",
- str(lr1.Conflict(12, "'A'", ["R", "S", "T"])))
+ def test_conflict_str(self):
+ self.assertEqual(
+ "Conflict for 'A' in state 12: R vs S",
+ str(lr1.Conflict(12, "'A'", ["R", "S"])),
+ )
+ self.assertEqual(
+ "Conflict for 'A' in state 12: R vs S vs T",
+ str(lr1.Conflict(12, "'A'", ["R", "S", "T"])),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py
index c9ba765..bd27c8a 100644
--- a/compiler/front_end/module_ir.py
+++ b/compiler/front_end/module_ir.py
@@ -33,51 +33,54 @@
# Intermediate types; should not be found in the final IR.
class _List(object):
- """A list with source location information."""
- __slots__ = ('list', 'source_location')
+ """A list with source location information."""
- def __init__(self, l):
- assert isinstance(l, list), "_List object must wrap list, not '%r'" % l
- self.list = l
- self.source_location = ir_data.Location()
+ __slots__ = ("list", "source_location")
+
+ def __init__(self, l):
+ assert isinstance(l, list), "_List object must wrap list, not '%r'" % l
+ self.list = l
+ self.source_location = ir_data.Location()
class _ExpressionTail(object):
- """A fragment of an expression with an operator and right-hand side.
+ """A fragment of an expression with an operator and right-hand side.
- _ExpressionTail is the tail of an expression, consisting of an operator and
- the right-hand argument to the operator; for example, in the expression (6+8),
- the _ExpressionTail would be "+8".
+ _ExpressionTail is the tail of an expression, consisting of an operator and
+ the right-hand argument to the operator; for example, in the expression (6+8),
+ the _ExpressionTail would be "+8".
- This is used as a temporary object while converting the right-recursive
- "expression" and "times-expression" productions into left-associative
- Expressions.
+ This is used as a temporary object while converting the right-recursive
+ "expression" and "times-expression" productions into left-associative
+ Expressions.
- Attributes:
- operator: An ir_data.Word of the operator's name.
- expression: The expression on the right side of the operator.
- source_location: The source location of the operation fragment.
- """
- __slots__ = ('operator', 'expression', 'source_location')
+ Attributes:
+ operator: An ir_data.Word of the operator's name.
+ expression: The expression on the right side of the operator.
+ source_location: The source location of the operation fragment.
+ """
- def __init__(self, operator, expression):
- self.operator = operator
- self.expression = expression
- self.source_location = ir_data.Location()
+ __slots__ = ("operator", "expression", "source_location")
+
+ def __init__(self, operator, expression):
+ self.operator = operator
+ self.expression = expression
+ self.source_location = ir_data.Location()
class _FieldWithType(object):
- """A field with zero or more types defined inline with that field."""
- __slots__ = ('field', 'subtypes', 'source_location')
+ """A field with zero or more types defined inline with that field."""
- def __init__(self, field, subtypes=None):
- self.field = field
- self.subtypes = subtypes or []
- self.source_location = ir_data.Location()
+ __slots__ = ("field", "subtypes", "source_location")
+
+ def __init__(self, field, subtypes=None):
+ self.field = field
+ self.subtypes = subtypes or []
+ self.source_location = ir_data.Location()
def build_ir(parse_tree, used_productions=None):
- r"""Builds a module-level intermediate representation from a valid parse tree.
+ r"""Builds a module-level intermediate representation from a valid parse tree.
The parse tree is precisely dictated by the exact productions in the grammar
used by the parser, with no semantic information. _really_build_ir transforms
@@ -129,37 +132,39 @@
a forest of module IRs so that names from other modules can be resolved.
"""
- # TODO(b/140259131): Refactor _really_build_ir to be less recursive/use an
- # explicit stack.
- old_recursion_limit = sys.getrecursionlimit()
- sys.setrecursionlimit(16 * 1024) # ~8000 top-level entities in one module.
- try:
- result = _really_build_ir(parse_tree, used_productions)
- finally:
- sys.setrecursionlimit(old_recursion_limit)
- return result
+ # TODO(b/140259131): Refactor _really_build_ir to be less recursive/use an
+ # explicit stack.
+ old_recursion_limit = sys.getrecursionlimit()
+ sys.setrecursionlimit(16 * 1024) # ~8000 top-level entities in one module.
+ try:
+ result = _really_build_ir(parse_tree, used_productions)
+ finally:
+ sys.setrecursionlimit(old_recursion_limit)
+ return result
def _really_build_ir(parse_tree, used_productions):
- """Real implementation of build_ir()."""
- if used_productions is None:
- used_productions = set()
- if hasattr(parse_tree, 'children'):
- parsed_children = [_really_build_ir(child, used_productions)
- for child in parse_tree.children]
- used_productions.add(parse_tree.production)
- result = _handlers[parse_tree.production](*parsed_children)
- if parse_tree.source_location is not None:
- if result.source_location:
- ir_data_utils.update(result.source_location, parse_tree.source_location)
- else:
- result.source_location = ir_data_utils.copy(parse_tree.source_location)
- return result
- else:
- # For leaf nodes, the temporary "IR" is just the token. Higher-level rules
- # will translate it to a real IR.
- assert isinstance(parse_tree, parser_types.Token), str(parse_tree)
- return parse_tree
+ """Real implementation of build_ir()."""
+ if used_productions is None:
+ used_productions = set()
+ if hasattr(parse_tree, "children"):
+ parsed_children = [
+ _really_build_ir(child, used_productions) for child in parse_tree.children
+ ]
+ used_productions.add(parse_tree.production)
+ result = _handlers[parse_tree.production](*parsed_children)
+ if parse_tree.source_location is not None:
+ if result.source_location:
+ ir_data_utils.update(result.source_location, parse_tree.source_location)
+ else:
+ result.source_location = ir_data_utils.copy(parse_tree.source_location)
+ return result
+ else:
+ # For leaf nodes, the temporary "IR" is just the token. Higher-level rules
+ # will translate it to a real IR.
+ assert isinstance(parse_tree, parser_types.Token), str(parse_tree)
+ return parse_tree
+
# Map of productions to their handlers.
_handlers = {}
@@ -168,58 +173,59 @@
def _get_anonymous_field_name():
- global _anonymous_name_counter
- _anonymous_name_counter += 1
- return 'emboss_reserved_anonymous_field_{}'.format(_anonymous_name_counter)
+ global _anonymous_name_counter
+ _anonymous_name_counter += 1
+ return "emboss_reserved_anonymous_field_{}".format(_anonymous_name_counter)
def _handles(production_text):
- """_handles marks a function as the handler for a particular production."""
- production = parser_types.Production.parse(production_text)
+ """_handles marks a function as the handler for a particular production."""
+ production = parser_types.Production.parse(production_text)
- def handles(f):
- _handlers[production] = f
- return f
+ def handles(f):
+ _handlers[production] = f
+ return f
- return handles
+ return handles
def _make_prelude_import(position):
- """Helper function to construct a synthetic ir_data.Import for the prelude."""
- location = parser_types.make_location(position, position)
- return ir_data.Import(
- file_name=ir_data.String(text='', source_location=location),
- local_name=ir_data.Word(text='', source_location=location),
- source_location=location)
+ """Helper function to construct a synthetic ir_data.Import for the prelude."""
+ location = parser_types.make_location(position, position)
+ return ir_data.Import(
+ file_name=ir_data.String(text="", source_location=location),
+ local_name=ir_data.Word(text="", source_location=location),
+ source_location=location,
+ )
def _text_to_operator(text):
- """Converts an operator's textual name to its corresponding enum."""
- operations = {
- '+': ir_data.FunctionMapping.ADDITION,
- '-': ir_data.FunctionMapping.SUBTRACTION,
- '*': ir_data.FunctionMapping.MULTIPLICATION,
- '==': ir_data.FunctionMapping.EQUALITY,
- '!=': ir_data.FunctionMapping.INEQUALITY,
- '&&': ir_data.FunctionMapping.AND,
- '||': ir_data.FunctionMapping.OR,
- '>': ir_data.FunctionMapping.GREATER,
- '>=': ir_data.FunctionMapping.GREATER_OR_EQUAL,
- '<': ir_data.FunctionMapping.LESS,
- '<=': ir_data.FunctionMapping.LESS_OR_EQUAL,
- }
- return operations[text]
+ """Converts an operator's textual name to its corresponding enum."""
+ operations = {
+ "+": ir_data.FunctionMapping.ADDITION,
+ "-": ir_data.FunctionMapping.SUBTRACTION,
+ "*": ir_data.FunctionMapping.MULTIPLICATION,
+ "==": ir_data.FunctionMapping.EQUALITY,
+ "!=": ir_data.FunctionMapping.INEQUALITY,
+ "&&": ir_data.FunctionMapping.AND,
+ "||": ir_data.FunctionMapping.OR,
+ ">": ir_data.FunctionMapping.GREATER,
+ ">=": ir_data.FunctionMapping.GREATER_OR_EQUAL,
+ "<": ir_data.FunctionMapping.LESS,
+ "<=": ir_data.FunctionMapping.LESS_OR_EQUAL,
+ }
+ return operations[text]
def _text_to_function(text):
- """Converts a function's textual name to its corresponding enum."""
- functions = {
- '$max': ir_data.FunctionMapping.MAXIMUM,
- '$present': ir_data.FunctionMapping.PRESENCE,
- '$upper_bound': ir_data.FunctionMapping.UPPER_BOUND,
- '$lower_bound': ir_data.FunctionMapping.LOWER_BOUND,
- }
- return functions[text]
+ """Converts a function's textual name to its corresponding enum."""
+ functions = {
+ "$max": ir_data.FunctionMapping.MAXIMUM,
+ "$present": ir_data.FunctionMapping.PRESENCE,
+ "$upper_bound": ir_data.FunctionMapping.UPPER_BOUND,
+ "$lower_bound": ir_data.FunctionMapping.LOWER_BOUND,
+ }
+ return functions[text]
################################################################################
@@ -245,136 +251,153 @@
# A module file is a list of documentation, then imports, then top-level
# attributes, then type definitions. Any section may be missing.
# TODO(bolms): Should Emboss disallow completely empty files?
-@_handles('module -> comment-line* doc-line* import-line* attribute-line*'
- ' type-definition*')
+@_handles(
+ "module -> comment-line* doc-line* import-line* attribute-line*"
+ " type-definition*"
+)
def _file(leading_newlines, docs, imports, attributes, type_definitions):
- """Assembles the top-level IR for a module."""
- del leading_newlines # Unused.
- # Figure out the best synthetic source_location for the synthesized prelude
- # import.
- if imports.list:
- position = imports.list[0].source_location.start
- elif docs.list:
- position = docs.list[0].source_location.end
- elif attributes.list:
- position = attributes.list[0].source_location.start
- elif type_definitions.list:
- position = type_definitions.list[0].source_location.start
- else:
- position = 1, 1
+ """Assembles the top-level IR for a module."""
+ del leading_newlines # Unused.
+ # Figure out the best synthetic source_location for the synthesized prelude
+ # import.
+ if imports.list:
+ position = imports.list[0].source_location.start
+ elif docs.list:
+ position = docs.list[0].source_location.end
+ elif attributes.list:
+ position = attributes.list[0].source_location.start
+ elif type_definitions.list:
+ position = type_definitions.list[0].source_location.start
+ else:
+ position = 1, 1
- # If the source file is completely empty, build_ir won't automatically
- # populate the source_location attribute for the module.
- if (not docs.list and not imports.list and not attributes.list and
- not type_definitions.list):
- module_source_location = parser_types.make_location((1, 1), (1, 1))
- else:
- module_source_location = None
+ # If the source file is completely empty, build_ir won't automatically
+ # populate the source_location attribute for the module.
+ if (
+ not docs.list
+ and not imports.list
+ and not attributes.list
+ and not type_definitions.list
+ ):
+ module_source_location = parser_types.make_location((1, 1), (1, 1))
+ else:
+ module_source_location = None
- return ir_data.Module(
- documentation=docs.list,
- foreign_import=[_make_prelude_import(position)] + imports.list,
- attribute=attributes.list,
- type=type_definitions.list,
- source_location=module_source_location)
+ return ir_data.Module(
+ documentation=docs.list,
+ foreign_import=[_make_prelude_import(position)] + imports.list,
+ attribute=attributes.list,
+ type=type_definitions.list,
+ source_location=module_source_location,
+ )
-@_handles('import-line ->'
- ' "import" string-constant "as" snake-word Comment? eol')
+@_handles("import-line ->" ' "import" string-constant "as" snake-word Comment? eol')
def _import(import_, file_name, as_, local_name, comment, eol):
- del import_, as_, comment, eol # Unused
- return ir_data.Import(file_name=file_name, local_name=local_name)
+ del import_, as_, comment, eol # Unused
+ return ir_data.Import(file_name=file_name, local_name=local_name)
-@_handles('doc-line -> doc Comment? eol')
+@_handles("doc-line -> doc Comment? eol")
def _doc_line(doc, comment, eol):
- del comment, eol # Unused.
- return doc
+ del comment, eol # Unused.
+ return doc
-@_handles('doc -> Documentation')
+@_handles("doc -> Documentation")
def _doc(documentation):
- # As a special case, an empty documentation string may omit the trailing
- # space.
- if documentation.text == '--':
- doc_text = '-- '
- else:
- doc_text = documentation.text
- assert doc_text[0:3] == '-- ', (
- "Documentation token '{}' in unknown format.".format(
- documentation.text))
- return ir_data.Documentation(text=doc_text[3:])
+ # As a special case, an empty documentation string may omit the trailing
+ # space.
+ if documentation.text == "--":
+ doc_text = "-- "
+ else:
+ doc_text = documentation.text
+ assert doc_text[0:3] == "-- ", "Documentation token '{}' in unknown format.".format(
+ documentation.text
+ )
+ return ir_data.Documentation(text=doc_text[3:])
# A attribute-line is just a attribute on its own line.
-@_handles('attribute-line -> attribute Comment? eol')
+@_handles("attribute-line -> attribute Comment? eol")
def _attribute_line(attr, comment, eol):
- del comment, eol # Unused.
- return attr
+ del comment, eol # Unused.
+ return attr
# A attribute is [name = value].
-@_handles('attribute -> "[" attribute-context? "$default"?'
- ' snake-word ":" attribute-value "]"')
-def _attribute(open_bracket, context_specifier, default_specifier, name, colon,
- attribute_value, close_bracket):
- del open_bracket, colon, close_bracket # Unused.
- if context_specifier.list:
- return ir_data.Attribute(name=name,
- value=attribute_value,
- is_default=bool(default_specifier.list),
- back_end=context_specifier.list[0])
- else:
- return ir_data.Attribute(name=name,
- value=attribute_value,
- is_default=bool(default_specifier.list))
+@_handles(
+ 'attribute -> "[" attribute-context? "$default"?'
+ ' snake-word ":" attribute-value "]"'
+)
+def _attribute(
+ open_bracket,
+ context_specifier,
+ default_specifier,
+ name,
+ colon,
+ attribute_value,
+ close_bracket,
+):
+ del open_bracket, colon, close_bracket # Unused.
+ if context_specifier.list:
+ return ir_data.Attribute(
+ name=name,
+ value=attribute_value,
+ is_default=bool(default_specifier.list),
+ back_end=context_specifier.list[0],
+ )
+ else:
+ return ir_data.Attribute(
+ name=name, value=attribute_value, is_default=bool(default_specifier.list)
+ )
@_handles('attribute-context -> "(" snake-word ")"')
def _attribute_context(open_paren, context_name, close_paren):
- del open_paren, close_paren # Unused.
- return context_name
+ del open_paren, close_paren # Unused.
+ return context_name
-@_handles('attribute-value -> expression')
+@_handles("attribute-value -> expression")
def _attribute_value_expression(expression):
- return ir_data.AttributeValue(expression=expression)
+ return ir_data.AttributeValue(expression=expression)
-@_handles('attribute-value -> string-constant')
+@_handles("attribute-value -> string-constant")
def _attribute_value_string(string):
- return ir_data.AttributeValue(string_constant=string)
+ return ir_data.AttributeValue(string_constant=string)
-@_handles('boolean-constant -> BooleanConstant')
+@_handles("boolean-constant -> BooleanConstant")
def _boolean_constant(boolean):
- return ir_data.BooleanConstant(value=(boolean.text == 'true'))
+ return ir_data.BooleanConstant(value=(boolean.text == "true"))
-@_handles('string-constant -> String')
+@_handles("string-constant -> String")
def _string_constant(string):
- """Turns a String token into an ir_data.String, with proper unescaping.
+ """Turns a String token into an ir_data.String, with proper unescaping.
- Arguments:
- string: A String token.
+ Arguments:
+ string: A String token.
- Returns:
- An ir_data.String with the "text" field set to the unescaped value of
- string.text.
- """
- # TODO(bolms): If/when this logic becomes more complex (e.g., to handle \NNN
- # or \xNN escapes), extract this into a separate module with separate tests.
- assert string.text[0] == '"'
- assert string.text[-1] == '"'
- assert len(string.text) >= 2
- result = []
- for substring in re.split(r'(\\.)', string.text[1:-1]):
- if substring and substring[0] == '\\':
- assert len(substring) == 2
- result.append({'\\': '\\', '"': '"', 'n': '\n'}[substring[1]])
- else:
- result.append(substring)
- return ir_data.String(text=''.join(result))
+ Returns:
+ An ir_data.String with the "text" field set to the unescaped value of
+ string.text.
+ """
+ # TODO(bolms): If/when this logic becomes more complex (e.g., to handle \NNN
+ # or \xNN escapes), extract this into a separate module with separate tests.
+ assert string.text[0] == '"'
+ assert string.text[-1] == '"'
+ assert len(string.text) >= 2
+ result = []
+ for substring in re.split(r"(\\.)", string.text[1:-1]):
+ if substring and substring[0] == "\\":
+ assert len(substring) == 2
+ result.append({"\\": "\\", '"': '"', "n": "\n"}[substring[1]])
+ else:
+ result.append(substring)
+ return ir_data.String(text="".join(result))
# In Emboss, '&&' and '||' may not be mixed without parentheses. These are all
@@ -419,179 +442,205 @@
# and-expression-right -> '&&' equality-expression
#
# In either case, explicit parenthesization is handled elsewhere in the grammar.
-@_handles('logical-expression -> and-expression')
-@_handles('logical-expression -> or-expression')
-@_handles('logical-expression -> comparison-expression')
-@_handles('choice-expression -> logical-expression')
-@_handles('expression -> choice-expression')
+@_handles("logical-expression -> and-expression")
+@_handles("logical-expression -> or-expression")
+@_handles("logical-expression -> comparison-expression")
+@_handles("choice-expression -> logical-expression")
+@_handles("expression -> choice-expression")
def _expression(expression):
- return expression
+ return expression
# The `logical-expression`s here means that ?: can't be chained without
# parentheses. `x < 0 ? -1 : (x == 0 ? 0 : 1)` is OK, but `x < 0 ? -1 : x == 0
# ? 0 : 1` is not. Parentheses are also needed in the middle: `x <= 0 ? x < 0 ?
# -1 : 0 : 1` is not syntactically valid.
-@_handles('choice-expression -> logical-expression "?" logical-expression'
- ' ":" logical-expression')
+@_handles(
+ 'choice-expression -> logical-expression "?" logical-expression'
+ ' ":" logical-expression'
+)
def _choice_expression(condition, question, if_true, colon, if_false):
- location = parser_types.make_location(
- condition.source_location.start, if_false.source_location.end)
- operator_location = parser_types.make_location(
- question.source_location.start, colon.source_location.end)
- # The function_name is a bit weird, but should suffice for any error messages
- # that might need it.
- return ir_data.Expression(
- function=ir_data.Function(function=ir_data.FunctionMapping.CHOICE,
- args=[condition, if_true, if_false],
- function_name=ir_data.Word(
- text='?:',
- source_location=operator_location),
- source_location=location))
-
-
-@_handles('comparison-expression -> additive-expression')
-def _no_op_comparative_expression(expression):
- return expression
-
-
-@_handles('comparison-expression ->'
- ' additive-expression inequality-operator additive-expression')
-def _comparative_expression(left, operator, right):
- location = parser_types.make_location(
- left.source_location.start, right.source_location.end)
- return ir_data.Expression(
- function=ir_data.Function(function=_text_to_operator(operator.text),
- args=[left, right],
- function_name=operator,
- source_location=location))
-
-
-@_handles('additive-expression -> times-expression additive-expression-right*')
-@_handles('times-expression -> negation-expression times-expression-right*')
-@_handles('and-expression -> comparison-expression and-expression-right+')
-@_handles('or-expression -> comparison-expression or-expression-right+')
-def _binary_operator_expression(expression, expression_right):
- """Builds the IR for a chain of equal-precedence left-associative operations.
-
- _binary_operator_expression transforms a right-recursive list of expression
- tails into a left-associative Expression tree. For example, given the
- arguments:
-
- 6, (Tail("+", 7), Tail("-", 8), Tail("+", 10))
-
- _expression produces a structure like:
-
- Expression(Expression(Expression(6, "+", 7), "-", 8), "+", 10)
-
- This transformation is necessary because strict LR(1) grammars do not allow
- left recursion.
-
- Note that this method is used for several productions; each of those
- productions handles a different precedence level, but are identical in form.
-
- Arguments:
- expression: An ir_data.Expression which is the head of the (expr, operator,
- expr, operator, expr, ...) list.
- expression_right: A list of _ExpressionTails corresponding to the (operator,
- expr, operator, expr, ...) list that comes after expression.
-
- Returns:
- An ir_data.Expression with the correct recursive structure to represent a
- list of left-associative operations.
- """
- e = expression
- for right in expression_right.list:
location = parser_types.make_location(
- e.source_location.start, right.source_location.end)
- e = ir_data.Expression(
+ condition.source_location.start, if_false.source_location.end
+ )
+ operator_location = parser_types.make_location(
+ question.source_location.start, colon.source_location.end
+ )
+ # The function_name is a bit weird, but should suffice for any error messages
+ # that might need it.
+ return ir_data.Expression(
function=ir_data.Function(
- function=_text_to_operator(right.operator.text),
- args=[e, right.expression],
- function_name=right.operator,
- source_location=location),
- source_location=location)
- return e
+ function=ir_data.FunctionMapping.CHOICE,
+ args=[condition, if_true, if_false],
+ function_name=ir_data.Word(text="?:", source_location=operator_location),
+ source_location=location,
+ )
+ )
-@_handles('comparison-expression ->'
- ' additive-expression equality-expression-right+')
-@_handles('comparison-expression ->'
- ' additive-expression less-expression-right-list')
-@_handles('comparison-expression ->'
- ' additive-expression greater-expression-right-list')
-def _chained_comparison_expression(expression, expression_right):
- """Builds the IR for a chain of comparisons, like a == b == c.
+@_handles("comparison-expression -> additive-expression")
+def _no_op_comparative_expression(expression):
+ return expression
- Like _binary_operator_expression, _chained_comparison_expression transforms a
- right-recursive list of expression tails into a left-associative Expression
- tree. Unlike _binary_operator_expression, extra AND nodes are added. For
- example, the following expression:
- 0 <= b <= 64
-
- must be translated to the conceptually-equivalent expression:
-
- 0 <= b && b <= 64
-
- (The middle subexpression is duplicated -- this would be a problem in a
- programming language like C where expressions like `x++` have side effects,
- but side effects do not make sense in a data definition language like Emboss.)
-
- _chained_comparison_expression receives a left-hand head expression and a list
- of tails, like:
-
- 6, (Tail("<=", b), Tail("<=", 64))
-
- which it translates to a structure like:
-
- Expression(Expression(6, "<=", b), "&&", Expression(b, "<=", 64))
-
- The Emboss grammar is constructed such that sequences of "<", "<=", and "=="
- comparisons may be chained, and sequences of ">", ">=", and "==" can be
- chained, but greater and less-than comparisons may not; e.g., "b < 64 > a" is
- not allowed.
-
- Arguments:
- expression: An ir_data.Expression which is the head of the (expr, operator,
- expr, operator, expr, ...) list.
- expression_right: A list of _ExpressionTails corresponding to the (operator,
- expr, operator, expr, ...) list that comes after expression.
-
- Returns:
- An ir_data.Expression with the correct recursive structure to represent a
- chain of left-associative comparison operations.
- """
- sequence = [expression]
- for right in expression_right.list:
- sequence.append(right.operator)
- sequence.append(right.expression)
- comparisons = []
- for i in range(0, len(sequence) - 1, 2):
- left, operator, right = sequence[i:i+3]
+@_handles(
+ "comparison-expression ->"
+ " additive-expression inequality-operator additive-expression"
+)
+def _comparative_expression(left, operator, right):
location = parser_types.make_location(
- left.source_location.start, right.source_location.end)
- comparisons.append(ir_data.Expression(
+ left.source_location.start, right.source_location.end
+ )
+ return ir_data.Expression(
function=ir_data.Function(
function=_text_to_operator(operator.text),
args=[left, right],
function_name=operator,
- source_location=location),
- source_location=location))
- e = comparisons[0]
- for comparison in comparisons[1:]:
- location = parser_types.make_location(
- e.source_location.start, comparison.source_location.end)
- e = ir_data.Expression(
- function=ir_data.Function(
- function=ir_data.FunctionMapping.AND,
- args=[e, comparison],
- function_name=ir_data.Word(
- text='&&',
- source_location=comparison.function.args[0].source_location),
- source_location=location),
- source_location=location)
- return e
+ source_location=location,
+ )
+ )
+
+
+@_handles("additive-expression -> times-expression additive-expression-right*")
+@_handles("times-expression -> negation-expression times-expression-right*")
+@_handles("and-expression -> comparison-expression and-expression-right+")
+@_handles("or-expression -> comparison-expression or-expression-right+")
+def _binary_operator_expression(expression, expression_right):
+ """Builds the IR for a chain of equal-precedence left-associative operations.
+
+ _binary_operator_expression transforms a right-recursive list of expression
+ tails into a left-associative Expression tree. For example, given the
+ arguments:
+
+ 6, (Tail("+", 7), Tail("-", 8), Tail("+", 10))
+
+ _expression produces a structure like:
+
+ Expression(Expression(Expression(6, "+", 7), "-", 8), "+", 10)
+
+ This transformation is necessary because strict LR(1) grammars do not allow
+ left recursion.
+
+ Note that this method is used for several productions; each of those
+ productions handles a different precedence level, but are identical in form.
+
+ Arguments:
+ expression: An ir_data.Expression which is the head of the (expr, operator,
+ expr, operator, expr, ...) list.
+ expression_right: A list of _ExpressionTails corresponding to the (operator,
+ expr, operator, expr, ...) list that comes after expression.
+
+ Returns:
+ An ir_data.Expression with the correct recursive structure to represent a
+ list of left-associative operations.
+ """
+ e = expression
+ for right in expression_right.list:
+ location = parser_types.make_location(
+ e.source_location.start, right.source_location.end
+ )
+ e = ir_data.Expression(
+ function=ir_data.Function(
+ function=_text_to_operator(right.operator.text),
+ args=[e, right.expression],
+ function_name=right.operator,
+ source_location=location,
+ ),
+ source_location=location,
+ )
+ return e
+
+
+@_handles(
+ "comparison-expression ->" " additive-expression equality-expression-right+"
+)
+@_handles(
+ "comparison-expression ->" " additive-expression less-expression-right-list"
+)
+@_handles(
+ "comparison-expression ->" " additive-expression greater-expression-right-list"
+)
+def _chained_comparison_expression(expression, expression_right):
+ """Builds the IR for a chain of comparisons, like a == b == c.
+
+ Like _binary_operator_expression, _chained_comparison_expression transforms a
+ right-recursive list of expression tails into a left-associative Expression
+ tree. Unlike _binary_operator_expression, extra AND nodes are added. For
+ example, the following expression:
+
+ 0 <= b <= 64
+
+ must be translated to the conceptually-equivalent expression:
+
+ 0 <= b && b <= 64
+
+ (The middle subexpression is duplicated -- this would be a problem in a
+ programming language like C where expressions like `x++` have side effects,
+ but side effects do not make sense in a data definition language like Emboss.)
+
+ _chained_comparison_expression receives a left-hand head expression and a list
+ of tails, like:
+
+ 6, (Tail("<=", b), Tail("<=", 64))
+
+ which it translates to a structure like:
+
+ Expression(Expression(6, "<=", b), "&&", Expression(b, "<=", 64))
+
+ The Emboss grammar is constructed such that sequences of "<", "<=", and "=="
+ comparisons may be chained, and sequences of ">", ">=", and "==" can be
+ chained, but greater and less-than comparisons may not; e.g., "b < 64 > a" is
+ not allowed.
+
+ Arguments:
+ expression: An ir_data.Expression which is the head of the (expr, operator,
+ expr, operator, expr, ...) list.
+ expression_right: A list of _ExpressionTails corresponding to the (operator,
+ expr, operator, expr, ...) list that comes after expression.
+
+ Returns:
+ An ir_data.Expression with the correct recursive structure to represent a
+ chain of left-associative comparison operations.
+ """
+ sequence = [expression]
+ for right in expression_right.list:
+ sequence.append(right.operator)
+ sequence.append(right.expression)
+ comparisons = []
+ for i in range(0, len(sequence) - 1, 2):
+ left, operator, right = sequence[i : i + 3]
+ location = parser_types.make_location(
+ left.source_location.start, right.source_location.end
+ )
+ comparisons.append(
+ ir_data.Expression(
+ function=ir_data.Function(
+ function=_text_to_operator(operator.text),
+ args=[left, right],
+ function_name=operator,
+ source_location=location,
+ ),
+ source_location=location,
+ )
+ )
+ e = comparisons[0]
+ for comparison in comparisons[1:]:
+ location = parser_types.make_location(
+ e.source_location.start, comparison.source_location.end
+ )
+ e = ir_data.Expression(
+ function=ir_data.Function(
+ function=ir_data.FunctionMapping.AND,
+ args=[e, comparison],
+ function_name=ir_data.Word(
+ text="&&",
+ source_location=comparison.function.args[0].source_location,
+ ),
+ source_location=location,
+ ),
+ source_location=location,
+ )
+ return e
# _chained_comparison_expression, above, handles three types of chains: `a == b
@@ -629,279 +678,315 @@
#
# By using `equality-expression-right*` for the first symbol, only the first
# parse is possible.
-@_handles('greater-expression-right-list ->'
- ' equality-expression-right* greater-expression-right'
- ' equality-or-greater-expression-right*')
-@_handles('less-expression-right-list ->'
- ' equality-expression-right* less-expression-right'
- ' equality-or-less-expression-right*')
+@_handles(
+ "greater-expression-right-list ->"
+ " equality-expression-right* greater-expression-right"
+ " equality-or-greater-expression-right*"
+)
+@_handles(
+ "less-expression-right-list ->"
+ " equality-expression-right* less-expression-right"
+ " equality-or-less-expression-right*"
+)
def _chained_comparison_tails(start, middle, end):
- return _List(start.list + [middle] + end.list)
+ return _List(start.list + [middle] + end.list)
-@_handles('equality-or-greater-expression-right -> equality-expression-right')
-@_handles('equality-or-greater-expression-right -> greater-expression-right')
-@_handles('equality-or-less-expression-right -> equality-expression-right')
-@_handles('equality-or-less-expression-right -> less-expression-right')
+@_handles("equality-or-greater-expression-right -> equality-expression-right")
+@_handles("equality-or-greater-expression-right -> greater-expression-right")
+@_handles("equality-or-less-expression-right -> equality-expression-right")
+@_handles("equality-or-less-expression-right -> less-expression-right")
def _equality_or_less_or_greater(right):
- return right
+ return right
-@_handles('and-expression-right -> and-operator comparison-expression')
-@_handles('or-expression-right -> or-operator comparison-expression')
-@_handles('additive-expression-right -> additive-operator times-expression')
-@_handles('equality-expression-right -> equality-operator additive-expression')
-@_handles('greater-expression-right -> greater-operator additive-expression')
-@_handles('less-expression-right -> less-operator additive-expression')
-@_handles('times-expression-right ->'
- ' multiplicative-operator negation-expression')
+@_handles("and-expression-right -> and-operator comparison-expression")
+@_handles("or-expression-right -> or-operator comparison-expression")
+@_handles("additive-expression-right -> additive-operator times-expression")
+@_handles("equality-expression-right -> equality-operator additive-expression")
+@_handles("greater-expression-right -> greater-operator additive-expression")
+@_handles("less-expression-right -> less-operator additive-expression")
+@_handles("times-expression-right ->" " multiplicative-operator negation-expression")
def _expression_right_production(operator, expression):
- return _ExpressionTail(operator, expression)
+ return _ExpressionTail(operator, expression)
# This supports a single layer of unary plus/minus, so "+5" and "-value" are
# allowed, but "+-5" or "-+-something" are not.
-@_handles('negation-expression -> additive-operator bottom-expression')
+@_handles("negation-expression -> additive-operator bottom-expression")
def _negation_expression_with_operator(operator, expression):
- phantom_zero_location = ir_data.Location(start=operator.source_location.start,
- end=operator.source_location.start)
- return ir_data.Expression(
- function=ir_data.Function(
- function=_text_to_operator(operator.text),
- args=[ir_data.Expression(
- constant=ir_data.NumericConstant(
- value='0',
- source_location=phantom_zero_location),
- source_location=phantom_zero_location), expression],
- function_name=operator,
- source_location=ir_data.Location(
- start=operator.source_location.start,
- end=expression.source_location.end)))
+ phantom_zero_location = ir_data.Location(
+ start=operator.source_location.start, end=operator.source_location.start
+ )
+ return ir_data.Expression(
+ function=ir_data.Function(
+ function=_text_to_operator(operator.text),
+ args=[
+ ir_data.Expression(
+ constant=ir_data.NumericConstant(
+ value="0", source_location=phantom_zero_location
+ ),
+ source_location=phantom_zero_location,
+ ),
+ expression,
+ ],
+ function_name=operator,
+ source_location=ir_data.Location(
+ start=operator.source_location.start, end=expression.source_location.end
+ ),
+ )
+ )
-@_handles('negation-expression -> bottom-expression')
+@_handles("negation-expression -> bottom-expression")
def _negation_expression(expression):
- return expression
+ return expression
@_handles('bottom-expression -> "(" expression ")"')
def _bottom_expression_parentheses(open_paren, expression, close_paren):
- del open_paren, close_paren # Unused.
- return expression
+ del open_paren, close_paren # Unused.
+ return expression
@_handles('bottom-expression -> function-name "(" argument-list ")"')
def _bottom_expression_function(function, open_paren, arguments, close_paren):
- del open_paren # Unused.
- return ir_data.Expression(
- function=ir_data.Function(
- function=_text_to_function(function.text),
- args=arguments.list,
- function_name=function,
- source_location=ir_data.Location(
- start=function.source_location.start,
- end=close_paren.source_location.end)))
+ del open_paren # Unused.
+ return ir_data.Expression(
+ function=ir_data.Function(
+ function=_text_to_function(function.text),
+ args=arguments.list,
+ function_name=function,
+ source_location=ir_data.Location(
+ start=function.source_location.start,
+ end=close_paren.source_location.end,
+ ),
+ )
+ )
@_handles('comma-then-expression -> "," expression')
def _comma_then_expression(comma, expression):
- del comma # Unused.
- return expression
+ del comma # Unused.
+ return expression
-@_handles('argument-list -> expression comma-then-expression*')
+@_handles("argument-list -> expression comma-then-expression*")
def _argument_list(head, tail):
- tail.list.insert(0, head)
- return tail
+ tail.list.insert(0, head)
+ return tail
-@_handles('argument-list ->')
+@_handles("argument-list ->")
def _empty_argument_list():
- return _List([])
+ return _List([])
-@_handles('bottom-expression -> numeric-constant')
+@_handles("bottom-expression -> numeric-constant")
def _bottom_expression_from_numeric_constant(constant):
- return ir_data.Expression(constant=constant)
+ return ir_data.Expression(constant=constant)
-@_handles('bottom-expression -> constant-reference')
+@_handles("bottom-expression -> constant-reference")
def _bottom_expression_from_constant_reference(reference):
- return ir_data.Expression(constant_reference=reference)
+ return ir_data.Expression(constant_reference=reference)
-@_handles('bottom-expression -> builtin-reference')
+@_handles("bottom-expression -> builtin-reference")
def _bottom_expression_from_builtin(reference):
- return ir_data.Expression(builtin_reference=reference)
+ return ir_data.Expression(builtin_reference=reference)
-@_handles('bottom-expression -> boolean-constant')
+@_handles("bottom-expression -> boolean-constant")
def _bottom_expression_from_boolean_constant(boolean):
- return ir_data.Expression(boolean_constant=boolean)
+ return ir_data.Expression(boolean_constant=boolean)
-@_handles('bottom-expression -> field-reference')
+@_handles("bottom-expression -> field-reference")
def _bottom_expression_from_reference(reference):
- return reference
+ return reference
-@_handles('field-reference -> snake-reference field-reference-tail*')
+@_handles("field-reference -> snake-reference field-reference-tail*")
def _indirect_field_reference(field_reference, field_references):
- if field_references.source_location.HasField('end'):
- end_location = field_references.source_location.end
- else:
- end_location = field_reference.source_location.end
- return ir_data.Expression(field_reference=ir_data.FieldReference(
- path=[field_reference] + field_references.list,
- source_location=parser_types.make_location(
- field_reference.source_location.start, end_location)))
+ if field_references.source_location.HasField("end"):
+ end_location = field_references.source_location.end
+ else:
+ end_location = field_reference.source_location.end
+ return ir_data.Expression(
+ field_reference=ir_data.FieldReference(
+ path=[field_reference] + field_references.list,
+ source_location=parser_types.make_location(
+ field_reference.source_location.start, end_location
+ ),
+ )
+ )
# If "Type.field" ever becomes syntactically valid, it will be necessary to
# check that enum values are compile-time constants.
@_handles('field-reference-tail -> "." snake-reference')
def _field_reference_tail(dot, reference):
- del dot # Unused.
- return reference
+ del dot # Unused.
+ return reference
-@_handles('numeric-constant -> Number')
+@_handles("numeric-constant -> Number")
def _numeric_constant(number):
- # All types of numeric constant tokenize to the same symbol, because they are
- # interchangeable in source code.
- if number.text[0:2] == '0b':
- n = int(number.text.replace('_', '')[2:], 2)
- elif number.text[0:2] == '0x':
- n = int(number.text.replace('_', '')[2:], 16)
- else:
- n = int(number.text.replace('_', ''), 10)
- return ir_data.NumericConstant(value=str(n))
+ # All types of numeric constant tokenize to the same symbol, because they are
+ # interchangeable in source code.
+ if number.text[0:2] == "0b":
+ n = int(number.text.replace("_", "")[2:], 2)
+ elif number.text[0:2] == "0x":
+ n = int(number.text.replace("_", "")[2:], 16)
+ else:
+ n = int(number.text.replace("_", ""), 10)
+ return ir_data.NumericConstant(value=str(n))
-@_handles('type-definition -> struct')
-@_handles('type-definition -> bits')
-@_handles('type-definition -> enum')
-@_handles('type-definition -> external')
+@_handles("type-definition -> struct")
+@_handles("type-definition -> bits")
+@_handles("type-definition -> enum")
+@_handles("type-definition -> external")
def _type_definition(type_definition):
- return type_definition
+ return type_definition
# struct StructureName:
# ... fields ...
# bits BitName:
# ... fields ...
-@_handles('struct -> "struct" type-name delimited-parameter-definition-list?'
- ' ":" Comment? eol struct-body')
-@_handles('bits -> "bits" type-name delimited-parameter-definition-list? ":"'
- ' Comment? eol bits-body')
+@_handles(
+ 'struct -> "struct" type-name delimited-parameter-definition-list?'
+ ' ":" Comment? eol struct-body'
+)
+@_handles(
+ 'bits -> "bits" type-name delimited-parameter-definition-list? ":"'
+ " Comment? eol bits-body"
+)
def _structure(struct, name, parameters, colon, comment, newline, struct_body):
- """Composes the top-level IR for an Emboss structure."""
- del colon, comment, newline # Unused.
- ir_data_utils.builder(struct_body.structure).source_location.start.CopyFrom(
- struct.source_location.start)
- ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom(
- struct_body.source_location.end)
- if struct_body.name:
- ir_data_utils.update(struct_body.name, name)
- else:
- struct_body.name = ir_data_utils.copy(name)
- if parameters.list:
- struct_body.runtime_parameter.extend(parameters.list[0].list)
- return struct_body
+ """Composes the top-level IR for an Emboss structure."""
+ del colon, comment, newline # Unused.
+ ir_data_utils.builder(struct_body.structure).source_location.start.CopyFrom(
+ struct.source_location.start
+ )
+ ir_data_utils.builder(struct_body.structure).source_location.end.CopyFrom(
+ struct_body.source_location.end
+ )
+ if struct_body.name:
+ ir_data_utils.update(struct_body.name, name)
+ else:
+ struct_body.name = ir_data_utils.copy(name)
+ if parameters.list:
+ struct_body.runtime_parameter.extend(parameters.list[0].list)
+ return struct_body
-@_handles('delimited-parameter-definition-list ->'
- ' "(" parameter-definition-list ")"')
+@_handles(
+ "delimited-parameter-definition-list ->" ' "(" parameter-definition-list ")"'
+)
def _delimited_parameter_definition_list(open_paren, parameters, close_paren):
- del open_paren, close_paren # Unused
- return parameters
+ del open_paren, close_paren # Unused
+ return parameters
@_handles('parameter-definition -> snake-name ":" type')
def _parameter_definition(name, double_colon, parameter_type):
- del double_colon # Unused
- return ir_data.RuntimeParameter(name=name, physical_type_alias=parameter_type)
+ del double_colon # Unused
+ return ir_data.RuntimeParameter(name=name, physical_type_alias=parameter_type)
@_handles('parameter-definition-list-tail -> "," parameter-definition')
def _parameter_definition_list_tail(comma, parameter):
- del comma # Unused.
- return parameter
+ del comma # Unused.
+ return parameter
-@_handles('parameter-definition-list -> parameter-definition'
- ' parameter-definition-list-tail*')
+@_handles(
+ "parameter-definition-list -> parameter-definition"
+ " parameter-definition-list-tail*"
+)
def _parameter_definition_list(head, tail):
- tail.list.insert(0, head)
- return tail
+ tail.list.insert(0, head)
+ return tail
-@_handles('parameter-definition-list ->')
+@_handles("parameter-definition-list ->")
def _empty_parameter_definition_list():
- return _List([])
+ return _List([])
# The body of a struct: basically, the part after the first line.
-@_handles('struct-body -> Indent doc-line* attribute-line*'
- ' type-definition* struct-field-block Dedent')
+@_handles(
+ "struct-body -> Indent doc-line* attribute-line*"
+ " type-definition* struct-field-block Dedent"
+)
def _struct_body(indent, docs, attributes, types, fields, dedent):
- del indent, dedent # Unused.
- return _structure_body(docs, attributes, types, fields,
- ir_data.AddressableUnit.BYTE)
+ del indent, dedent # Unused.
+ return _structure_body(
+ docs, attributes, types, fields, ir_data.AddressableUnit.BYTE
+ )
def _structure_body(docs, attributes, types, fields, addressable_unit):
- """Constructs the body of a structure (bits or struct) definition."""
- return ir_data.TypeDefinition(
- structure=ir_data.Structure(field=[field.field for field in fields.list]),
- documentation=docs.list,
- attribute=attributes.list,
- subtype=types.list + [subtype for field in fields.list for subtype in
- field.subtypes],
- addressable_unit=addressable_unit)
+ """Constructs the body of a structure (bits or struct) definition."""
+ return ir_data.TypeDefinition(
+ structure=ir_data.Structure(field=[field.field for field in fields.list]),
+ documentation=docs.list,
+ attribute=attributes.list,
+ subtype=types.list
+ + [subtype for field in fields.list for subtype in field.subtypes],
+ addressable_unit=addressable_unit,
+ )
-@_handles('struct-field-block ->')
-@_handles('bits-field-block ->')
-@_handles('anonymous-bits-field-block ->')
+@_handles("struct-field-block ->")
+@_handles("bits-field-block ->")
+@_handles("anonymous-bits-field-block ->")
def _empty_field_block():
- return _List([])
+ return _List([])
-@_handles('struct-field-block ->'
- ' conditional-struct-field-block struct-field-block')
-@_handles('bits-field-block ->'
- ' conditional-bits-field-block bits-field-block')
-@_handles('anonymous-bits-field-block -> conditional-anonymous-bits-field-block'
- ' anonymous-bits-field-block')
+@_handles(
+ "struct-field-block ->" " conditional-struct-field-block struct-field-block"
+)
+@_handles("bits-field-block ->" " conditional-bits-field-block bits-field-block")
+@_handles(
+ "anonymous-bits-field-block -> conditional-anonymous-bits-field-block"
+ " anonymous-bits-field-block"
+)
def _conditional_block_plus_field_block(conditional_block, block):
- return _List(conditional_block.list + block.list)
+ return _List(conditional_block.list + block.list)
-@_handles('struct-field-block ->'
- ' unconditional-struct-field struct-field-block')
-@_handles('bits-field-block ->'
- ' unconditional-bits-field bits-field-block')
-@_handles('anonymous-bits-field-block ->'
- ' unconditional-anonymous-bits-field anonymous-bits-field-block')
+@_handles("struct-field-block ->" " unconditional-struct-field struct-field-block")
+@_handles("bits-field-block ->" " unconditional-bits-field bits-field-block")
+@_handles(
+ "anonymous-bits-field-block ->"
+ " unconditional-anonymous-bits-field anonymous-bits-field-block"
+)
def _unconditional_block_plus_field_block(field, block):
- """Prepends an unconditional field to block."""
- ir_data_utils.builder(field.field).existence_condition.source_location.CopyFrom(
- field.source_location)
- ir_data_utils.builder(field.field).existence_condition.boolean_constant.source_location.CopyFrom(
- field.source_location)
- ir_data_utils.builder(field.field).existence_condition.boolean_constant.value = True
- return _List([field] + block.list)
+ """Prepends an unconditional field to block."""
+ ir_data_utils.builder(field.field).existence_condition.source_location.CopyFrom(
+ field.source_location
+ )
+ ir_data_utils.builder(
+ field.field
+ ).existence_condition.boolean_constant.source_location.CopyFrom(
+ field.source_location
+ )
+ ir_data_utils.builder(field.field).existence_condition.boolean_constant.value = True
+ return _List([field] + block.list)
# Struct "fields" are regular fields, inline enums, bits, or structs, anonymous
# inline bits, or virtual fields.
-@_handles('unconditional-struct-field -> field')
-@_handles('unconditional-struct-field -> inline-enum-field-definition')
-@_handles('unconditional-struct-field -> inline-bits-field-definition')
-@_handles('unconditional-struct-field -> inline-struct-field-definition')
-@_handles('unconditional-struct-field -> anonymous-bits-field-definition')
-@_handles('unconditional-struct-field -> virtual-field')
+@_handles("unconditional-struct-field -> field")
+@_handles("unconditional-struct-field -> inline-enum-field-definition")
+@_handles("unconditional-struct-field -> inline-bits-field-definition")
+@_handles("unconditional-struct-field -> inline-struct-field-definition")
+@_handles("unconditional-struct-field -> anonymous-bits-field-definition")
+@_handles("unconditional-struct-field -> virtual-field")
# Bits fields are "regular" fields, inline enums or bits, or virtual fields.
#
# Inline structs and anonymous inline bits are not allowed inside of bits:
@@ -910,55 +995,66 @@
#
# Anonymous inline bits may not include virtual fields; instead, the virtual
# field should be a direct part of the enclosing structure.
-@_handles('unconditional-anonymous-bits-field -> field')
-@_handles('unconditional-anonymous-bits-field -> inline-enum-field-definition')
-@_handles('unconditional-anonymous-bits-field -> inline-bits-field-definition')
-@_handles('unconditional-bits-field -> unconditional-anonymous-bits-field')
-@_handles('unconditional-bits-field -> virtual-field')
+@_handles("unconditional-anonymous-bits-field -> field")
+@_handles("unconditional-anonymous-bits-field -> inline-enum-field-definition")
+@_handles("unconditional-anonymous-bits-field -> inline-bits-field-definition")
+@_handles("unconditional-bits-field -> unconditional-anonymous-bits-field")
+@_handles("unconditional-bits-field -> virtual-field")
def _unconditional_field(field):
- """Handles the unifying grammar production for a struct or bits field."""
- return field
+ """Handles the unifying grammar production for a struct or bits field."""
+ return field
# TODO(bolms): Add 'elif' and 'else' support.
# TODO(bolms): Should nested 'if' blocks be allowed?
-@_handles('conditional-struct-field-block ->'
- ' "if" expression ":" Comment? eol'
- ' Indent unconditional-struct-field+ Dedent')
-@_handles('conditional-bits-field-block ->'
- ' "if" expression ":" Comment? eol'
- ' Indent unconditional-bits-field+ Dedent')
-@_handles('conditional-anonymous-bits-field-block ->'
- ' "if" expression ":" Comment? eol'
- ' Indent unconditional-anonymous-bits-field+ Dedent')
-def _conditional_field_block(if_keyword, expression, colon, comment, newline,
- indent, fields, dedent):
- """Applies an existence_condition to each element of fields."""
- del if_keyword, newline, colon, comment, indent, dedent # Unused.
- for field in fields.list:
- condition = ir_data_utils.builder(field.field).existence_condition
- condition.CopyFrom(expression)
- condition.source_location.is_disjoint_from_parent = True
- return fields
+@_handles(
+ "conditional-struct-field-block ->"
+ ' "if" expression ":" Comment? eol'
+ " Indent unconditional-struct-field+ Dedent"
+)
+@_handles(
+ "conditional-bits-field-block ->"
+ ' "if" expression ":" Comment? eol'
+ " Indent unconditional-bits-field+ Dedent"
+)
+@_handles(
+ "conditional-anonymous-bits-field-block ->"
+ ' "if" expression ":" Comment? eol'
+ " Indent unconditional-anonymous-bits-field+ Dedent"
+)
+def _conditional_field_block(
+ if_keyword, expression, colon, comment, newline, indent, fields, dedent
+):
+ """Applies an existence_condition to each element of fields."""
+ del if_keyword, newline, colon, comment, indent, dedent # Unused.
+ for field in fields.list:
+ condition = ir_data_utils.builder(field.field).existence_condition
+ condition.CopyFrom(expression)
+ condition.source_location.is_disjoint_from_parent = True
+ return fields
# The body of a bit field definition: basically, the part after the first line.
-@_handles('bits-body -> Indent doc-line* attribute-line*'
- ' type-definition* bits-field-block Dedent')
+@_handles(
+ "bits-body -> Indent doc-line* attribute-line*"
+ " type-definition* bits-field-block Dedent"
+)
def _bits_body(indent, docs, attributes, types, fields, dedent):
- del indent, dedent # Unused.
- return _structure_body(docs, attributes, types, fields,
- ir_data.AddressableUnit.BIT)
+ del indent, dedent # Unused.
+ return _structure_body(docs, attributes, types, fields, ir_data.AddressableUnit.BIT)
# Inline bits (defined as part of a field) are more restricted than standalone
# bits.
-@_handles('anonymous-bits-body ->'
- ' Indent attribute-line* anonymous-bits-field-block Dedent')
+@_handles(
+ "anonymous-bits-body ->"
+ " Indent attribute-line* anonymous-bits-field-block Dedent"
+)
def _anonymous_bits_body(indent, attributes, fields, dedent):
- del indent, dedent # Unused.
- return _structure_body(_List([]), attributes, _List([]), fields,
- ir_data.AddressableUnit.BIT)
+ del indent, dedent # Unused.
+ return _structure_body(
+ _List([]), attributes, _List([]), fields, ir_data.AddressableUnit.BIT
+ )
# A field is:
@@ -967,30 +1063,43 @@
# -- doc
# [attr3: value]
# [attr4: value]
-@_handles('field ->'
- ' field-location type snake-name abbreviation? attribute* doc?'
- ' Comment? eol field-body?')
-def _field(location, field_type, name, abbreviation, attributes, doc, comment,
- newline, field_body):
- """Constructs an ir_data.Field from the given components."""
- del comment # Unused
- field_ir = ir_data.Field(location=location,
- type=field_type,
- name=name,
- attribute=attributes.list,
- documentation=doc.list)
- field = ir_data_utils.builder(field_ir)
- if field_body.list:
- field.attribute.extend(field_body.list[0].attribute)
- field.documentation.extend(field_body.list[0].documentation)
- if abbreviation.list:
- field.abbreviation.CopyFrom(abbreviation.list[0])
- field.source_location.start.CopyFrom(location.source_location.start)
- if field_body.source_location.HasField('end'):
- field.source_location.end.CopyFrom(field_body.source_location.end)
- else:
- field.source_location.end.CopyFrom(newline.source_location.end)
- return _FieldWithType(field=field_ir)
+@_handles(
+ "field ->"
+ " field-location type snake-name abbreviation? attribute* doc?"
+ " Comment? eol field-body?"
+)
+def _field(
+ location,
+ field_type,
+ name,
+ abbreviation,
+ attributes,
+ doc,
+ comment,
+ newline,
+ field_body,
+):
+ """Constructs an ir_data.Field from the given components."""
+ del comment # Unused
+ field_ir = ir_data.Field(
+ location=location,
+ type=field_type,
+ name=name,
+ attribute=attributes.list,
+ documentation=doc.list,
+ )
+ field = ir_data_utils.builder(field_ir)
+ if field_body.list:
+ field.attribute.extend(field_body.list[0].attribute)
+ field.documentation.extend(field_body.list[0].documentation)
+ if abbreviation.list:
+ field.abbreviation.CopyFrom(abbreviation.list[0])
+ field.source_location.start.CopyFrom(location.source_location.start)
+ if field_body.source_location.HasField("end"):
+ field.source_location.end.CopyFrom(field_body.source_location.end)
+ else:
+ field.source_location.end.CopyFrom(newline.source_location.end)
+ return _FieldWithType(field=field_ir)
# A "virtual field" is:
@@ -999,22 +1108,23 @@
# -- doc
# [attr1: value]
# [attr2: value]
-@_handles('virtual-field ->'
- ' "let" snake-name "=" expression Comment? eol field-body?')
+@_handles(
+ "virtual-field ->" ' "let" snake-name "=" expression Comment? eol field-body?'
+)
def _virtual_field(let, name, equals, value, comment, newline, field_body):
- """Constructs an ir_data.Field from the given components."""
- del equals, comment # Unused
- field_ir = ir_data.Field(read_transform=value, name=name)
- field = ir_data_utils.builder(field_ir)
- if field_body.list:
- field.attribute.extend(field_body.list[0].attribute)
- field.documentation.extend(field_body.list[0].documentation)
- field.source_location.start.CopyFrom(let.source_location.start)
- if field_body.source_location.HasField('end'):
- field.source_location.end.CopyFrom(field_body.source_location.end)
- else:
- field.source_location.end.CopyFrom(newline.source_location.end)
- return _FieldWithType(field=field_ir)
+ """Constructs an ir_data.Field from the given components."""
+ del equals, comment # Unused
+ field_ir = ir_data.Field(read_transform=value, name=name)
+ field = ir_data_utils.builder(field_ir)
+ if field_body.list:
+ field.attribute.extend(field_body.list[0].attribute)
+ field.documentation.extend(field_body.list[0].documentation)
+ field.source_location.start.CopyFrom(let.source_location.start)
+ if field_body.source_location.HasField("end"):
+ field.source_location.end.CopyFrom(field_body.source_location.end)
+ else:
+ field.source_location.end.CopyFrom(newline.source_location.end)
+ return _FieldWithType(field=field_ir)
# An inline enum is:
@@ -1025,252 +1135,292 @@
# [attr4: value]
# NAME = 10
# NAME2 = 20
-@_handles('inline-enum-field-definition ->'
- ' field-location "enum" snake-name abbreviation? ":" Comment? eol'
- ' enum-body')
-def _inline_enum_field(location, enum, name, abbreviation, colon, comment,
- newline, enum_body):
- """Constructs an ir_data.Field for an inline enum field."""
- del enum, colon, comment, newline # Unused.
- return _inline_type_field(location, name, abbreviation, enum_body)
+@_handles(
+ "inline-enum-field-definition ->"
+ ' field-location "enum" snake-name abbreviation? ":" Comment? eol'
+ " enum-body"
+)
+def _inline_enum_field(
+ location, enum, name, abbreviation, colon, comment, newline, enum_body
+):
+ """Constructs an ir_data.Field for an inline enum field."""
+ del enum, colon, comment, newline # Unused.
+ return _inline_type_field(location, name, abbreviation, enum_body)
@_handles(
- 'inline-struct-field-definition ->'
+ "inline-struct-field-definition ->"
' field-location "struct" snake-name abbreviation? ":" Comment? eol'
- ' struct-body')
-def _inline_struct_field(location, struct, name, abbreviation, colon, comment,
- newline, struct_body):
- del struct, colon, comment, newline # Unused.
- return _inline_type_field(location, name, abbreviation, struct_body)
+ " struct-body"
+)
+def _inline_struct_field(
+ location, struct, name, abbreviation, colon, comment, newline, struct_body
+):
+ del struct, colon, comment, newline # Unused.
+ return _inline_type_field(location, name, abbreviation, struct_body)
-@_handles('inline-bits-field-definition ->'
- ' field-location "bits" snake-name abbreviation? ":" Comment? eol'
- ' bits-body')
-def _inline_bits_field(location, bits, name, abbreviation, colon, comment,
- newline, bits_body):
- del bits, colon, comment, newline # Unused.
- return _inline_type_field(location, name, abbreviation, bits_body)
+@_handles(
+ "inline-bits-field-definition ->"
+ ' field-location "bits" snake-name abbreviation? ":" Comment? eol'
+ " bits-body"
+)
+def _inline_bits_field(
+ location, bits, name, abbreviation, colon, comment, newline, bits_body
+):
+ del bits, colon, comment, newline # Unused.
+ return _inline_type_field(location, name, abbreviation, bits_body)
def _inline_type_field(location, name, abbreviation, body):
- """Shared implementation of _inline_enum_field and _anonymous_bit_field."""
- field_ir = ir_data.Field(location=location,
- name=name,
- attribute=body.attribute,
- documentation=body.documentation)
- field = ir_data_utils.builder(field_ir)
- # All attributes should be attached to the field, not the type definition: if
- # the user wants to use type attributes, they should create a separate type
- # definition and reference it.
- del body.attribute[:]
- type_name = ir_data_utils.copy(name)
- ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(type_name.name.text)
- field.type.atomic_type.reference.source_name.extend([type_name.name])
- field.type.atomic_type.reference.source_location.CopyFrom(
- type_name.source_location)
- field.type.atomic_type.reference.is_local_name = True
- field.type.atomic_type.source_location.CopyFrom(type_name.source_location)
- field.type.source_location.CopyFrom(type_name.source_location)
- if abbreviation.list:
- field.abbreviation.CopyFrom(abbreviation.list[0])
- field.source_location.start.CopyFrom(location.source_location.start)
- ir_data_utils.builder(body.source_location).start.CopyFrom(location.source_location.start)
- if body.HasField('enumeration'):
- ir_data_utils.builder(body.enumeration).source_location.CopyFrom(body.source_location)
- else:
- assert body.HasField('structure')
- ir_data_utils.builder(body.structure).source_location.CopyFrom(body.source_location)
- ir_data_utils.builder(body).name.CopyFrom(type_name)
- field.source_location.end.CopyFrom(body.source_location.end)
- subtypes = [body] + list(body.subtype)
- del body.subtype[:]
- return _FieldWithType(field=field_ir, subtypes=subtypes)
+ """Shared implementation of _inline_enum_field and _anonymous_bit_field."""
+ field_ir = ir_data.Field(
+ location=location,
+ name=name,
+ attribute=body.attribute,
+ documentation=body.documentation,
+ )
+ field = ir_data_utils.builder(field_ir)
+ # All attributes should be attached to the field, not the type definition: if
+ # the user wants to use type attributes, they should create a separate type
+ # definition and reference it.
+ del body.attribute[:]
+ type_name = ir_data_utils.copy(name)
+ ir_data_utils.builder(type_name).name.text = name_conversion.snake_to_camel(
+ type_name.name.text
+ )
+ field.type.atomic_type.reference.source_name.extend([type_name.name])
+ field.type.atomic_type.reference.source_location.CopyFrom(type_name.source_location)
+ field.type.atomic_type.reference.is_local_name = True
+ field.type.atomic_type.source_location.CopyFrom(type_name.source_location)
+ field.type.source_location.CopyFrom(type_name.source_location)
+ if abbreviation.list:
+ field.abbreviation.CopyFrom(abbreviation.list[0])
+ field.source_location.start.CopyFrom(location.source_location.start)
+ ir_data_utils.builder(body.source_location).start.CopyFrom(
+ location.source_location.start
+ )
+ if body.HasField("enumeration"):
+ ir_data_utils.builder(body.enumeration).source_location.CopyFrom(
+ body.source_location
+ )
+ else:
+ assert body.HasField("structure")
+ ir_data_utils.builder(body.structure).source_location.CopyFrom(
+ body.source_location
+ )
+ ir_data_utils.builder(body).name.CopyFrom(type_name)
+ field.source_location.end.CopyFrom(body.source_location.end)
+ subtypes = [body] + list(body.subtype)
+ del body.subtype[:]
+ return _FieldWithType(field=field_ir, subtypes=subtypes)
-@_handles('anonymous-bits-field-definition ->'
- ' field-location "bits" ":" Comment? eol anonymous-bits-body')
-def _anonymous_bit_field(location, bits_keyword, colon, comment, newline,
- bits_body):
- """Constructs an ir_data.Field for an anonymous bit field."""
- del colon, comment, newline # Unused.
- name = ir_data.NameDefinition(
- name=ir_data.Word(
- text=_get_anonymous_field_name(),
- source_location=bits_keyword.source_location),
- source_location=bits_keyword.source_location,
- is_anonymous=True)
- return _inline_type_field(location, name, _List([]), bits_body)
+@_handles(
+ "anonymous-bits-field-definition ->"
+ ' field-location "bits" ":" Comment? eol anonymous-bits-body'
+)
+def _anonymous_bit_field(location, bits_keyword, colon, comment, newline, bits_body):
+ """Constructs an ir_data.Field for an anonymous bit field."""
+ del colon, comment, newline # Unused.
+ name = ir_data.NameDefinition(
+ name=ir_data.Word(
+ text=_get_anonymous_field_name(),
+ source_location=bits_keyword.source_location,
+ ),
+ source_location=bits_keyword.source_location,
+ is_anonymous=True,
+ )
+ return _inline_type_field(location, name, _List([]), bits_body)
-@_handles('field-body -> Indent doc-line* attribute-line* Dedent')
+@_handles("field-body -> Indent doc-line* attribute-line* Dedent")
def _field_body(indent, docs, attributes, dedent):
- del indent, dedent # Unused.
- return ir_data.Field(documentation=docs.list, attribute=attributes.list)
+ del indent, dedent # Unused.
+ return ir_data.Field(documentation=docs.list, attribute=attributes.list)
# A parenthetically-denoted abbreviation.
@_handles('abbreviation -> "(" snake-word ")"')
def _abbreviation(open_paren, word, close_paren):
- del open_paren, close_paren # Unused.
- return word
+ del open_paren, close_paren # Unused.
+ return word
# enum EnumName:
# ... values ...
@_handles('enum -> "enum" type-name ":" Comment? eol enum-body')
def _enum(enum, name, colon, comment, newline, enum_body):
- del colon, comment, newline # Unused.
- ir_data_utils.builder(enum_body.enumeration).source_location.start.CopyFrom(
- enum.source_location.start)
- ir_data_utils.builder(enum_body.enumeration).source_location.end.CopyFrom(
- enum_body.source_location.end)
- ir_data_utils.builder(enum_body).name.CopyFrom(name)
- return enum_body
+ del colon, comment, newline # Unused.
+ ir_data_utils.builder(enum_body.enumeration).source_location.start.CopyFrom(
+ enum.source_location.start
+ )
+ ir_data_utils.builder(enum_body.enumeration).source_location.end.CopyFrom(
+ enum_body.source_location.end
+ )
+ ir_data_utils.builder(enum_body).name.CopyFrom(name)
+ return enum_body
# [enum Foo:]
# name = value
# name = value
-@_handles('enum-body -> Indent doc-line* attribute-line* enum-value+ Dedent')
+@_handles("enum-body -> Indent doc-line* attribute-line* enum-value+ Dedent")
def _enum_body(indent, docs, attributes, values, dedent):
- del indent, dedent # Unused.
- return ir_data.TypeDefinition(
- enumeration=ir_data.Enum(value=values.list),
- documentation=docs.list,
- attribute=attributes.list,
- addressable_unit=ir_data.AddressableUnit.BIT)
+ del indent, dedent # Unused.
+ return ir_data.TypeDefinition(
+ enumeration=ir_data.Enum(value=values.list),
+ documentation=docs.list,
+ attribute=attributes.list,
+ addressable_unit=ir_data.AddressableUnit.BIT,
+ )
# name = value
-@_handles('enum-value -> '
- ' constant-name "=" expression attribute* doc? Comment? eol enum-value-body?')
-def _enum_value(name, equals, expression, attribute, documentation, comment, newline,
- body):
- del equals, comment, newline # Unused.
- result = ir_data.EnumValue(name=name,
- value=expression,
- documentation=documentation.list,
- attribute=attribute.list)
- if body.list:
- result.documentation.extend(body.list[0].documentation)
- result.attribute.extend(body.list[0].attribute)
- return result
+@_handles(
+ "enum-value -> "
+ ' constant-name "=" expression attribute* doc? Comment? eol enum-value-body?'
+)
+def _enum_value(
+ name, equals, expression, attribute, documentation, comment, newline, body
+):
+ del equals, comment, newline # Unused.
+ result = ir_data.EnumValue(
+ name=name,
+ value=expression,
+ documentation=documentation.list,
+ attribute=attribute.list,
+ )
+ if body.list:
+ result.documentation.extend(body.list[0].documentation)
+ result.attribute.extend(body.list[0].attribute)
+ return result
-@_handles('enum-value-body -> Indent doc-line* attribute-line* Dedent')
+@_handles("enum-value-body -> Indent doc-line* attribute-line* Dedent")
def _enum_value_body(indent, docs, attributes, dedent):
- del indent, dedent # Unused.
- return ir_data.EnumValue(documentation=docs.list, attribute=attributes.list)
+ del indent, dedent # Unused.
+ return ir_data.EnumValue(documentation=docs.list, attribute=attributes.list)
# An external is just a declaration that a type exists and has certain
# attributes.
@_handles('external -> "external" type-name ":" Comment? eol external-body')
def _external(external, name, colon, comment, newline, external_body):
- del colon, comment, newline # Unused.
- ir_data_utils.builder(external_body.source_location).start.CopyFrom(external.source_location.start)
- if external_body.name:
- ir_data_utils.update(external_body.name, name)
- else:
- external_body.name = ir_data_utils.copy(name)
- return external_body
+ del colon, comment, newline # Unused.
+ ir_data_utils.builder(external_body.source_location).start.CopyFrom(
+ external.source_location.start
+ )
+ if external_body.name:
+ ir_data_utils.update(external_body.name, name)
+ else:
+ external_body.name = ir_data_utils.copy(name)
+ return external_body
# This syntax implicitly requires either a documentation line or a attribute
# line, or it won't parse (because no Indent/Dedent tokens will be emitted).
-@_handles('external-body -> Indent doc-line* attribute-line* Dedent')
+@_handles("external-body -> Indent doc-line* attribute-line* Dedent")
def _external_body(indent, docs, attributes, dedent):
- return ir_data.TypeDefinition(
- external=ir_data.External(
- # Set source_location here, since it won't be set automatically.
- source_location=ir_data.Location(start=indent.source_location.start,
- end=dedent.source_location.end)),
- documentation=docs.list,
- attribute=attributes.list)
+ return ir_data.TypeDefinition(
+ external=ir_data.External(
+ # Set source_location here, since it won't be set automatically.
+ source_location=ir_data.Location(
+ start=indent.source_location.start, end=dedent.source_location.end
+ )
+ ),
+ documentation=docs.list,
+ attribute=attributes.list,
+ )
@_handles('field-location -> expression "[" "+" expression "]"')
def _field_location(start, open_bracket, plus, size, close_bracket):
- del open_bracket, plus, close_bracket # Unused.
- return ir_data.FieldLocation(start=start, size=size)
+ del open_bracket, plus, close_bracket # Unused.
+ return ir_data.FieldLocation(start=start, size=size)
@_handles('delimited-argument-list -> "(" argument-list ")"')
def _type_argument_list(open_paren, arguments, close_paren):
- del open_paren, close_paren # Unused
- return arguments
+ del open_paren, close_paren # Unused
+ return arguments
# A type is "TypeName" or "TypeName[length]" or "TypeName[length][length]", etc.
# An array type may have an empty length ("Type[]"). This is only valid for the
# outermost length (the last set of brackets), but that must be checked
# elsewhere.
-@_handles('type -> type-reference delimited-argument-list? type-size-specifier?'
- ' array-length-specifier*')
+@_handles(
+ "type -> type-reference delimited-argument-list? type-size-specifier?"
+ " array-length-specifier*"
+)
def _type(reference, parameters, size, array_spec):
- """Builds the IR for a type specifier."""
- base_type_source_location_end = reference.source_location.end
- atomic_type_source_location_end = reference.source_location.end
- if parameters.list:
- base_type_source_location_end = parameters.source_location.end
- atomic_type_source_location_end = parameters.source_location.end
- if size.list:
- base_type_source_location_end = size.source_location.end
- base_type_location = parser_types.make_location(
- reference.source_location.start,
- base_type_source_location_end)
- atomic_type_location = parser_types.make_location(
- reference.source_location.start,
- atomic_type_source_location_end)
- t = ir_data.Type(
- atomic_type=ir_data.AtomicType(
- reference=ir_data_utils.copy(reference),
- source_location=atomic_type_location,
- runtime_parameter=parameters.list[0].list if parameters.list else []),
- size_in_bits=size.list[0] if size.list else None,
- source_location=base_type_location)
- for length in array_spec.list:
- location = parser_types.make_location(
- t.source_location.start, length.source_location.end)
- if isinstance(length, ir_data.Expression):
- t = ir_data.Type(
- array_type=ir_data.ArrayType(base_type=t,
- element_count=length,
- source_location=location),
- source_location=location)
- elif isinstance(length, ir_data.Empty):
- t = ir_data.Type(
- array_type=ir_data.ArrayType(base_type=t,
- automatic=length,
- source_location=location),
- source_location=location)
- else:
- assert False, "Shouldn't be here."
- return t
+ """Builds the IR for a type specifier."""
+ base_type_source_location_end = reference.source_location.end
+ atomic_type_source_location_end = reference.source_location.end
+ if parameters.list:
+ base_type_source_location_end = parameters.source_location.end
+ atomic_type_source_location_end = parameters.source_location.end
+ if size.list:
+ base_type_source_location_end = size.source_location.end
+ base_type_location = parser_types.make_location(
+ reference.source_location.start, base_type_source_location_end
+ )
+ atomic_type_location = parser_types.make_location(
+ reference.source_location.start, atomic_type_source_location_end
+ )
+ t = ir_data.Type(
+ atomic_type=ir_data.AtomicType(
+ reference=ir_data_utils.copy(reference),
+ source_location=atomic_type_location,
+ runtime_parameter=parameters.list[0].list if parameters.list else [],
+ ),
+ size_in_bits=size.list[0] if size.list else None,
+ source_location=base_type_location,
+ )
+ for length in array_spec.list:
+ location = parser_types.make_location(
+ t.source_location.start, length.source_location.end
+ )
+ if isinstance(length, ir_data.Expression):
+ t = ir_data.Type(
+ array_type=ir_data.ArrayType(
+ base_type=t, element_count=length, source_location=location
+ ),
+ source_location=location,
+ )
+ elif isinstance(length, ir_data.Empty):
+ t = ir_data.Type(
+ array_type=ir_data.ArrayType(
+ base_type=t, automatic=length, source_location=location
+ ),
+ source_location=location,
+ )
+ else:
+ assert False, "Shouldn't be here."
+ return t
# TODO(bolms): Should symbolic names or expressions be allowed? E.g.,
# UInt:FIELD_SIZE or UInt:(16 + 16)?
@_handles('type-size-specifier -> ":" numeric-constant')
def _type_size_specifier(colon, numeric_constant):
- """handles the ":32" part of a type specifier like "UInt:32"."""
- del colon
- return ir_data.Expression(constant=numeric_constant)
+ """handles the ":32" part of a type specifier like "UInt:32"."""
+ del colon
+ return ir_data.Expression(constant=numeric_constant)
# The distinctions between different formats of NameDefinitions, Words, and
# References are enforced during parsing, but not propagated to the IR.
-@_handles('type-name -> type-word')
-@_handles('snake-name -> snake-word')
-@_handles('constant-name -> constant-word')
+@_handles("type-name -> type-word")
+@_handles("snake-name -> snake-word")
+@_handles("constant-name -> constant-word")
def _name(word):
- return ir_data.NameDefinition(name=word)
+ return ir_data.NameDefinition(name=word)
-@_handles('type-word -> CamelWord')
-@_handles('snake-word -> SnakeWord')
+@_handles("type-word -> CamelWord")
+@_handles("snake-word -> SnakeWord")
@_handles('builtin-field-word -> "$size_in_bits"')
@_handles('builtin-field-word -> "$size_in_bytes"')
@_handles('builtin-field-word -> "$max_size_in_bits"')
@@ -1280,7 +1430,7 @@
@_handles('builtin-word -> "$is_statically_sized"')
@_handles('builtin-word -> "$static_size_in_bits"')
@_handles('builtin-word -> "$next"')
-@_handles('constant-word -> ShoutyWord')
+@_handles("constant-word -> ShoutyWord")
@_handles('and-operator -> "&&"')
@_handles('or-operator -> "||"')
@_handles('less-operator -> "<="')
@@ -1297,28 +1447,29 @@
@_handles('function-name -> "$upper_bound"')
@_handles('function-name -> "$lower_bound"')
def _word(word):
- return ir_data.Word(text=word.text)
+ return ir_data.Word(text=word.text)
-@_handles('type-reference -> type-reference-tail')
-@_handles('constant-reference -> constant-reference-tail')
+@_handles("type-reference -> type-reference-tail")
+@_handles("constant-reference -> constant-reference-tail")
def _un_module_qualified_type_reference(reference):
- return reference
+ return reference
-@_handles('constant-reference-tail -> constant-word')
-@_handles('type-reference-tail -> type-word')
-@_handles('snake-reference -> snake-word')
-@_handles('snake-reference -> builtin-field-word')
+@_handles("constant-reference-tail -> constant-word")
+@_handles("type-reference-tail -> type-word")
+@_handles("snake-reference -> snake-word")
+@_handles("snake-reference -> builtin-field-word")
def _reference(word):
- return ir_data.Reference(source_name=[word])
+ return ir_data.Reference(source_name=[word])
-@_handles('builtin-reference -> builtin-word')
+@_handles("builtin-reference -> builtin-word")
def _builtin_reference(word):
- return ir_data.Reference(source_name=[word],
- canonical_name=ir_data.CanonicalName(
- object_path=[word.text]))
+ return ir_data.Reference(
+ source_name=[word],
+ canonical_name=ir_data.CanonicalName(object_path=[word.text]),
+ )
# Because constant-references ("Enum.NAME") are used in the same contexts as
@@ -1334,11 +1485,11 @@
# "snake_name.snake_name".
@_handles('constant-reference -> snake-reference "." constant-reference-tail')
def _module_qualified_constant_reference(new_head, dot, reference):
- del dot # Unused.
- new_source_name = list(new_head.source_name) + list(reference.source_name)
- del reference.source_name[:]
- reference.source_name.extend(new_source_name)
- return reference
+ del dot # Unused.
+ new_source_name = list(new_head.source_name) + list(reference.source_name)
+ del reference.source_name[:]
+ reference.source_name.extend(new_source_name)
+ return reference
@_handles('constant-reference-tail -> type-word "." constant-reference-tail')
@@ -1348,69 +1499,70 @@
@_handles('type-reference-tail -> type-word "." type-reference-tail')
@_handles('type-reference -> snake-word "." type-reference-tail')
def _qualified_reference(word, dot, reference):
- """Adds a name. or Type. qualification to the head of a reference."""
- del dot # Unused.
- new_source_name = [word] + list(reference.source_name)
- del reference.source_name[:]
- reference.source_name.extend(new_source_name)
- return reference
+ """Adds a name. or Type. qualification to the head of a reference."""
+ del dot # Unused.
+ new_source_name = [word] + list(reference.source_name)
+ del reference.source_name[:]
+ reference.source_name.extend(new_source_name)
+ return reference
# Arrays are properly translated to IR in _type().
@_handles('array-length-specifier -> "[" expression "]"')
def _array_length_specifier(open_bracket, length, close_bracket):
- del open_bracket, close_bracket # Unused.
- return length
+ del open_bracket, close_bracket # Unused.
+ return length
# An array specifier can end with empty brackets ("arr[3][]"), in which case the
# array's size is inferred from the size of its enclosing field.
@_handles('array-length-specifier -> "[" "]"')
def _auto_array_length_specifier(open_bracket, close_bracket):
- # Note that the Void's source_location is the space between the brackets (if
- # any).
- return ir_data.Empty(
- source_location=ir_data.Location(start=open_bracket.source_location.end,
- end=close_bracket.source_location.start))
+ # Note that the Void's source_location is the space between the brackets (if
+ # any).
+ return ir_data.Empty(
+ source_location=ir_data.Location(
+ start=open_bracket.source_location.end,
+ end=close_bracket.source_location.start,
+ )
+ )
@_handles('eol -> "\\n" comment-line*')
def _eol(eol, comments):
- del comments # Unused
- return eol
+ del comments # Unused
+ return eol
@_handles('comment-line -> Comment? "\\n"')
def _comment_line(comment, eol):
- del comment # Unused
- return eol
+ del comment # Unused
+ return eol
def _finalize_grammar():
- """_Finalize adds productions for foo*, foo+, and foo? symbols."""
- star_symbols = set()
- plus_symbols = set()
- option_symbols = set()
- for production in _handlers:
- for symbol in production.rhs:
- if symbol[-1] == '*':
- star_symbols.add(symbol[:-1])
- elif symbol[-1] == '+':
- # symbol+ relies on the rule for symbol*
- star_symbols.add(symbol[:-1])
- plus_symbols.add(symbol[:-1])
- elif symbol[-1] == '?':
- option_symbols.add(symbol[:-1])
- for symbol in star_symbols:
- _handles('{s}* -> {s} {s}*'.format(s=symbol))(
- lambda e, r: _List([e] + r.list))
- _handles('{s}* ->'.format(s=symbol))(lambda: _List([]))
- for symbol in plus_symbols:
- _handles('{s}+ -> {s} {s}*'.format(s=symbol))(
- lambda e, r: _List([e] + r.list))
- for symbol in option_symbols:
- _handles('{s}? -> {s}'.format(s=symbol))(lambda e: _List([e]))
- _handles('{s}? ->'.format(s=symbol))(lambda: _List([]))
+ """_Finalize adds productions for foo*, foo+, and foo? symbols."""
+ star_symbols = set()
+ plus_symbols = set()
+ option_symbols = set()
+ for production in _handlers:
+ for symbol in production.rhs:
+ if symbol[-1] == "*":
+ star_symbols.add(symbol[:-1])
+ elif symbol[-1] == "+":
+ # symbol+ relies on the rule for symbol*
+ star_symbols.add(symbol[:-1])
+ plus_symbols.add(symbol[:-1])
+ elif symbol[-1] == "?":
+ option_symbols.add(symbol[:-1])
+ for symbol in star_symbols:
+ _handles("{s}* -> {s} {s}*".format(s=symbol))(lambda e, r: _List([e] + r.list))
+ _handles("{s}* ->".format(s=symbol))(lambda: _List([]))
+ for symbol in plus_symbols:
+ _handles("{s}+ -> {s} {s}*".format(s=symbol))(lambda e, r: _List([e] + r.list))
+ for symbol in option_symbols:
+ _handles("{s}? -> {s}".format(s=symbol))(lambda e: _List([e]))
+ _handles("{s}? ->".format(s=symbol))(lambda: _List([]))
_finalize_grammar()
@@ -1420,6 +1572,6 @@
# These export the grammar used by module_ir so that parser_generator can build
# a parser for the same language.
-START_SYMBOL = 'module'
-EXPRESSION_START_SYMBOL = 'expression'
+START_SYMBOL = "module"
+EXPRESSION_START_SYMBOL = "expression"
PRODUCTIONS = list(_handlers.keys())
diff --git a/compiler/front_end/module_ir_test.py b/compiler/front_end/module_ir_test.py
index 5faa107..ad884a8 100644
--- a/compiler/front_end/module_ir_test.py
+++ b/compiler/front_end/module_ir_test.py
@@ -30,12 +30,16 @@
_TESTDATA_PATH = "testdata.golden"
_MINIMAL_SOURCE = pkgutil.get_data(
- _TESTDATA_PATH, "span_se_log_file_status.emb").decode(encoding="UTF-8")
+ _TESTDATA_PATH, "span_se_log_file_status.emb"
+).decode(encoding="UTF-8")
_MINIMAL_SAMPLE = parser.parse_module(
- tokenizer.tokenize(_MINIMAL_SOURCE, "")[0]).parse_tree
-_MINIMAL_SAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.Module,
+ tokenizer.tokenize(_MINIMAL_SOURCE, "")[0]
+).parse_tree
+_MINIMAL_SAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Module,
pkgutil.get_data(_TESTDATA_PATH, "span_se_log_file_status.ir.txt").decode(
- encoding="UTF-8")
+ encoding="UTF-8"
+ ),
)
# _TEST_CASES contains test cases, separated by '===', that ensure that specific
@@ -3974,230 +3978,253 @@
def _get_test_cases():
- test_case = collections.namedtuple("test_case", ["name", "parse_tree", "ir"])
- result = []
- for case in _TEST_CASES.split("==="):
- name, emb, ir_text = case.split("---")
- name = name.strip()
- try:
- ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, ir_text)
- except Exception:
- print(name)
- raise
- parse_result = parser.parse_module(tokenizer.tokenize(emb, "")[0])
- assert not parse_result.error, "{}:\n{}".format(name, parse_result.error)
- result.append(test_case(name, parse_result.parse_tree, ir))
- return result
+ test_case = collections.namedtuple("test_case", ["name", "parse_tree", "ir"])
+ result = []
+ for case in _TEST_CASES.split("==="):
+ name, emb, ir_text = case.split("---")
+ name = name.strip()
+ try:
+ ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, ir_text)
+ except Exception:
+ print(name)
+ raise
+ parse_result = parser.parse_module(tokenizer.tokenize(emb, "")[0])
+ assert not parse_result.error, "{}:\n{}".format(name, parse_result.error)
+ result.append(test_case(name, parse_result.parse_tree, ir))
+ return result
def _get_negative_test_cases():
- test_case = collections.namedtuple("test_case",
- ["name", "text", "error_token"])
- result = []
- for case in _NEGATIVE_TEST_CASES.split("==="):
- name, error_token, text = case.split("---")
- name = name.strip()
- error_token = error_token.strip()
- result.append(test_case(name, text, error_token))
- return result
+ test_case = collections.namedtuple("test_case", ["name", "text", "error_token"])
+ result = []
+ for case in _NEGATIVE_TEST_CASES.split("==="):
+ name, error_token, text = case.split("---")
+ name = name.strip()
+ error_token = error_token.strip()
+ result.append(test_case(name, text, error_token))
+ return result
def _check_source_location(source_location, path, min_start, max_end):
- """Performs sanity checks on a source_location field.
+ """Performs sanity checks on a source_location field.
- Arguments:
- source_location: The source_location to check.
- path: The path, to use in error messages.
- min_start: A minimum value for source_location.start, or None.
- max_end: A maximum value for source_location.end, or None.
+ Arguments:
+ source_location: The source_location to check.
+ path: The path, to use in error messages.
+ min_start: A minimum value for source_location.start, or None.
+ max_end: A maximum value for source_location.end, or None.
- Returns:
- A list of error messages, or an empty list if no errors.
- """
- if source_location.is_disjoint_from_parent:
- # If source_location.is_disjoint_from_parent, then this source_location is
- # allowed to be outside of the parent's source_location.
- return []
+ Returns:
+ A list of error messages, or an empty list if no errors.
+ """
+ if source_location.is_disjoint_from_parent:
+ # If source_location.is_disjoint_from_parent, then this source_location is
+ # allowed to be outside of the parent's source_location.
+ return []
- result = []
- start = None
- end = None
- if not source_location.HasField("start"):
- result.append("{}.start missing".format(path))
- else:
- start = source_location.start
- if not source_location.HasField("end"):
- result.append("{}.end missing".format(path))
- else:
- end = source_location.end
-
- if start and end:
- if start.HasField("line") and end.HasField("line"):
- if start.line > end.line:
- result.append("{}.start.line > {}.end.line ({} vs {})".format(
- path, path, start.line, end.line))
- elif start.line == end.line:
- if (start.HasField("column") and end.HasField("column") and
- start.column > end.column):
- result.append("{}.start.column > {}.end.column ({} vs {})".format(
- path, path, start.column, end.column))
-
- for name, field in (("start", start), ("end", end)):
- if not field:
- continue
- if field.HasField("line"):
- if field.line <= 0:
- result.append("{}.{}.line <= 0 ({})".format(path, name, field.line))
+ result = []
+ start = None
+ end = None
+ if not source_location.HasField("start"):
+ result.append("{}.start missing".format(path))
else:
- result.append("{}.{}.line missing".format(path, name))
- if field.HasField("column"):
- if field.column <= 0:
- result.append("{}.{}.column <= 0 ({})".format(path, name, field.column))
+ start = source_location.start
+ if not source_location.HasField("end"):
+ result.append("{}.end missing".format(path))
else:
- result.append("{}.{}.column missing".format(path, name))
+ end = source_location.end
- if min_start and start:
- if min_start.line > start.line or (
- min_start.line == start.line and min_start.column > start.column):
- result.append("{}.start before parent start".format(path))
+ if start and end:
+ if start.HasField("line") and end.HasField("line"):
+ if start.line > end.line:
+ result.append(
+ "{}.start.line > {}.end.line ({} vs {})".format(
+ path, path, start.line, end.line
+ )
+ )
+ elif start.line == end.line:
+ if (
+ start.HasField("column")
+ and end.HasField("column")
+ and start.column > end.column
+ ):
+ result.append(
+ "{}.start.column > {}.end.column ({} vs {})".format(
+ path, path, start.column, end.column
+ )
+ )
- if max_end and end:
- if max_end.line < end.line or (
- max_end.line == end.line and max_end.column < end.column):
- result.append("{}.end after parent end".format(path))
+ for name, field in (("start", start), ("end", end)):
+ if not field:
+ continue
+ if field.HasField("line"):
+ if field.line <= 0:
+ result.append("{}.{}.line <= 0 ({})".format(path, name, field.line))
+ else:
+ result.append("{}.{}.line missing".format(path, name))
+ if field.HasField("column"):
+ if field.column <= 0:
+ result.append("{}.{}.column <= 0 ({})".format(path, name, field.column))
+ else:
+ result.append("{}.{}.column missing".format(path, name))
- return result
+ if min_start and start:
+ if min_start.line > start.line or (
+ min_start.line == start.line and min_start.column > start.column
+ ):
+ result.append("{}.start before parent start".format(path))
+
+ if max_end and end:
+ if max_end.line < end.line or (
+ max_end.line == end.line and max_end.column < end.column
+ ):
+ result.append("{}.end after parent end".format(path))
+
+ return result
def _check_all_source_locations(proto, path="", min_start=None, max_end=None):
- """Performs sanity checks on all source_locations in proto.
+ """Performs sanity checks on all source_locations in proto.
- Arguments:
- proto: The proto to recursively check.
- path: The path, to use in error messages.
- min_start: A minimum value for source_location.start, or None.
- max_end: A maximum value for source_location.end, or None.
+ Arguments:
+ proto: The proto to recursively check.
+ path: The path, to use in error messages.
+ min_start: A minimum value for source_location.start, or None.
+ max_end: A maximum value for source_location.end, or None.
- Returns:
- A list of error messages, or an empty list if no errors.
- """
- if path:
- path += "."
+ Returns:
+ A list of error messages, or an empty list if no errors.
+ """
+ if path:
+ path += "."
- errors = []
+ errors = []
- child_start = None
- child_end = None
- # Only check the source_location value if this proto message actually has a
- # source_location field.
- if proto.HasField("source_location"):
- errors.extend(_check_source_location(proto.source_location,
- path + "source_location",
- min_start, max_end))
- child_start = proto.source_location.start
- child_end = proto.source_location.end
+ child_start = None
+ child_end = None
+ # Only check the source_location value if this proto message actually has a
+ # source_location field.
+ if proto.HasField("source_location"):
+ errors.extend(
+ _check_source_location(
+ proto.source_location, path + "source_location", min_start, max_end
+ )
+ )
+ child_start = proto.source_location.start
+ child_end = proto.source_location.end
- for name, spec in ir_data_fields.field_specs(proto).items():
- if name == "source_location":
- continue
- if not proto.HasField(name):
- continue
- field_path = "{}{}".format(path, name)
- if spec.is_dataclass:
- if spec.is_sequence:
- index = 0
- for i in getattr(proto, name):
- item_path = "{}[{}]".format(field_path, index)
- index += 1
- errors.extend(
- _check_all_source_locations(i, item_path, child_start, child_end))
- else:
- errors.extend(_check_all_source_locations(getattr(proto, name),
- field_path, child_start,
- child_end))
+ for name, spec in ir_data_fields.field_specs(proto).items():
+ if name == "source_location":
+ continue
+ if not proto.HasField(name):
+ continue
+ field_path = "{}{}".format(path, name)
+ if spec.is_dataclass:
+ if spec.is_sequence:
+ index = 0
+ for i in getattr(proto, name):
+ item_path = "{}[{}]".format(field_path, index)
+ index += 1
+ errors.extend(
+ _check_all_source_locations(
+ i, item_path, child_start, child_end
+ )
+ )
+ else:
+ errors.extend(
+ _check_all_source_locations(
+ getattr(proto, name), field_path, child_start, child_end
+ )
+ )
- return errors
+ return errors
class ModuleIrTest(unittest.TestCase):
- """Tests the module_ir.build_ir() function."""
+ """Tests the module_ir.build_ir() function."""
- def test_build_ir(self):
- ir = module_ir.build_ir(_MINIMAL_SAMPLE)
- ir.source_text = _MINIMAL_SOURCE
- self.assertEqual(ir, _MINIMAL_SAMPLE_IR)
+ def test_build_ir(self):
+ ir = module_ir.build_ir(_MINIMAL_SAMPLE)
+ ir.source_text = _MINIMAL_SOURCE
+ self.assertEqual(ir, _MINIMAL_SAMPLE_IR)
- def test_production_coverage(self):
- """Checks that all grammar productions are used somewhere in tests."""
- used_productions = set()
- module_ir.build_ir(_MINIMAL_SAMPLE, used_productions)
- for test in _get_test_cases():
- module_ir.build_ir(test.parse_tree, used_productions)
- self.assertEqual(set(module_ir.PRODUCTIONS) - used_productions, set([]))
+ def test_production_coverage(self):
+ """Checks that all grammar productions are used somewhere in tests."""
+ used_productions = set()
+ module_ir.build_ir(_MINIMAL_SAMPLE, used_productions)
+ for test in _get_test_cases():
+ module_ir.build_ir(test.parse_tree, used_productions)
+ self.assertEqual(set(module_ir.PRODUCTIONS) - used_productions, set([]))
- def test_double_negative_non_compilation(self):
- """Checks that unparenthesized double unary minus/plus is a parse error."""
- for example in ("[x: - -3]", "[x: + -3]", "[x: - +3]", "[x: + +3]"):
- parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0])
- self.assertTrue(parse_result.error)
- self.assertEqual(7, parse_result.error.token.source_location.start.column)
- for example in ("[x:-(-3)]", "[x:+(-3)]", "[x:-(+3)]", "[x:+(+3)]"):
- parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0])
- self.assertFalse(parse_result.error)
+ def test_double_negative_non_compilation(self):
+ """Checks that unparenthesized double unary minus/plus is a parse error."""
+ for example in ("[x: - -3]", "[x: + -3]", "[x: - +3]", "[x: + +3]"):
+ parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0])
+ self.assertTrue(parse_result.error)
+ self.assertEqual(7, parse_result.error.token.source_location.start.column)
+ for example in ("[x:-(-3)]", "[x:+(-3)]", "[x:-(+3)]", "[x:+(+3)]"):
+ parse_result = parser.parse_module(tokenizer.tokenize(example, "")[0])
+ self.assertFalse(parse_result.error)
def _make_superset_tests():
- def _make_superset_test(test):
+ def _make_superset_test(test):
- def test_case(self):
- ir = module_ir.build_ir(test.parse_tree)
- is_superset, error_message = test_util.proto_is_superset(ir, test.ir)
+ def test_case(self):
+ ir = module_ir.build_ir(test.parse_tree)
+ is_superset, error_message = test_util.proto_is_superset(ir, test.ir)
- self.assertTrue(
- is_superset,
- error_message + "\n" + ir_data_utils.IrDataSerializer(ir).to_json(indent=2) + "\n" +
- ir_data_utils.IrDataSerializer(test.ir).to_json(indent=2))
+ self.assertTrue(
+ is_superset,
+ error_message
+ + "\n"
+ + ir_data_utils.IrDataSerializer(ir).to_json(indent=2)
+ + "\n"
+ + ir_data_utils.IrDataSerializer(test.ir).to_json(indent=2),
+ )
- return test_case
+ return test_case
- for test in _get_test_cases():
- test_name = "test " + test.name + " proto superset"
- assert not hasattr(ModuleIrTest, test_name)
- setattr(ModuleIrTest, test_name, _make_superset_test(test))
+ for test in _get_test_cases():
+ test_name = "test " + test.name + " proto superset"
+ assert not hasattr(ModuleIrTest, test_name)
+ setattr(ModuleIrTest, test_name, _make_superset_test(test))
def _make_source_location_tests():
- def _make_source_location_test(test):
+ def _make_source_location_test(test):
- def test_case(self):
- error_list = _check_all_source_locations(
- module_ir.build_ir(test.parse_tree))
- self.assertFalse(error_list, "\n".join([test.name] + error_list))
+ def test_case(self):
+ error_list = _check_all_source_locations(
+ module_ir.build_ir(test.parse_tree)
+ )
+ self.assertFalse(error_list, "\n".join([test.name] + error_list))
- return test_case
+ return test_case
- for test in _get_test_cases():
- test_name = "test " + test.name + " source location"
- assert not hasattr(ModuleIrTest, test_name)
- setattr(ModuleIrTest, test_name, _make_source_location_test(test))
+ for test in _get_test_cases():
+ test_name = "test " + test.name + " source location"
+ assert not hasattr(ModuleIrTest, test_name)
+ setattr(ModuleIrTest, test_name, _make_source_location_test(test))
def _make_negative_tests():
- def _make_negative_test(test):
+ def _make_negative_test(test):
- def test_case(self):
- parse_result = parser.parse_module(tokenizer.tokenize(test.text, "")[0])
- self.assertEqual(test.error_token, parse_result.error.token.text.strip())
+ def test_case(self):
+ parse_result = parser.parse_module(tokenizer.tokenize(test.text, "")[0])
+ self.assertEqual(test.error_token, parse_result.error.token.text.strip())
- return test_case
+ return test_case
- for test in _get_negative_test_cases():
- test_name = "test " + test.name + " compilation failure"
- assert not hasattr(ModuleIrTest, test_name)
- setattr(ModuleIrTest, test_name, _make_negative_test(test))
+ for test in _get_negative_test_cases():
+ test_name = "test " + test.name + " compilation failure"
+ assert not hasattr(ModuleIrTest, test_name)
+ setattr(ModuleIrTest, test_name, _make_negative_test(test))
+
_make_negative_tests()
_make_superset_tests()
@@ -4205,4 +4232,4 @@
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/parser.py b/compiler/front_end/parser.py
index 6ece324..a6d7130 100644
--- a/compiler/front_end/parser.py
+++ b/compiler/front_end/parser.py
@@ -22,104 +22,110 @@
class ParserGenerationError(Exception):
- """An error occurred during parser generation."""
- pass
+ """An error occurred during parser generation."""
+
+ pass
def parse_error_examples(error_example_text):
- """Parses error examples from error_example_text.
+ """Parses error examples from error_example_text.
- Arguments:
- error_example_text: The text of an error example file.
+ Arguments:
+ error_example_text: The text of an error example file.
- Returns:
- A list of tuples, suitable for passing into generate_parser.
+ Returns:
+ A list of tuples, suitable for passing into generate_parser.
- Raises:
- ParserGenerationError: There is a problem parsing the error examples.
- """
- error_examples = error_example_text.split("\n" + "=" * 80 + "\n")
- result = []
- # Everything before the first "======" line is explanatory text: ignore it.
- for error_example in error_examples[1:]:
- message_and_examples = error_example.split("\n" + "-" * 80 + "\n")
- if len(message_and_examples) != 2:
- raise ParserGenerationError(
- "Expected one error message and one example section in:\n" +
- error_example)
- message, example_text = message_and_examples
- examples = example_text.split("\n---\n")
- for example in examples:
- # TODO(bolms): feed a line number into tokenize, so that tokenization
- # failures refer to the correct line within error_example_text.
- tokens, errors = tokenizer.tokenize(example, "")
- if errors:
- raise ParserGenerationError(str(errors))
+ Raises:
+ ParserGenerationError: There is a problem parsing the error examples.
+ """
+ error_examples = error_example_text.split("\n" + "=" * 80 + "\n")
+ result = []
+ # Everything before the first "======" line is explanatory text: ignore it.
+ for error_example in error_examples[1:]:
+ message_and_examples = error_example.split("\n" + "-" * 80 + "\n")
+ if len(message_and_examples) != 2:
+ raise ParserGenerationError(
+ "Expected one error message and one example section in:\n"
+ + error_example
+ )
+ message, example_text = message_and_examples
+ examples = example_text.split("\n---\n")
+ for example in examples:
+ # TODO(bolms): feed a line number into tokenize, so that tokenization
+ # failures refer to the correct line within error_example_text.
+ tokens, errors = tokenizer.tokenize(example, "")
+ if errors:
+ raise ParserGenerationError(str(errors))
- for i in range(len(tokens)):
- if tokens[i].symbol == "BadWord" and tokens[i].text == "$ANY":
- tokens[i] = lr1.ANY_TOKEN
+ for i in range(len(tokens)):
+ if tokens[i].symbol == "BadWord" and tokens[i].text == "$ANY":
+ tokens[i] = lr1.ANY_TOKEN
- error_token = None
- for i in range(len(tokens)):
- if tokens[i].symbol == "BadWord" and tokens[i].text == "$ERR":
- error_token = tokens[i + 1]
- del tokens[i]
- break
- else:
- raise ParserGenerationError(
- "No error token marker '$ERR' in:\n" + error_example)
+ error_token = None
+ for i in range(len(tokens)):
+ if tokens[i].symbol == "BadWord" and tokens[i].text == "$ERR":
+ error_token = tokens[i + 1]
+ del tokens[i]
+ break
+ else:
+ raise ParserGenerationError(
+ "No error token marker '$ERR' in:\n" + error_example
+ )
- result.append((tokens, error_token, message.strip(), example))
- return result
+ result.append((tokens, error_token, message.strip(), example))
+ return result
def generate_parser(start_symbol, productions, error_examples):
- """Generates a parser from grammar, and applies error_examples.
+ """Generates a parser from grammar, and applies error_examples.
- Arguments:
- start_symbol: the start symbol of the grammar (a string)
- productions: a list of parser_types.Production in the grammar
- error_examples: A list of (source tokens, error message, source text)
- tuples.
+ Arguments:
+ start_symbol: the start symbol of the grammar (a string)
+ productions: a list of parser_types.Production in the grammar
+ error_examples: A list of (source tokens, error message, source text)
+ tuples.
- Returns:
- A parser.
+ Returns:
+ A parser.
- Raises:
- ParserGenerationError: There is a problem generating the parser.
- """
- parser = lr1.Grammar(start_symbol, productions).parser()
- if parser.conflicts:
- raise ParserGenerationError("\n".join([str(c) for c in parser.conflicts]))
- for example in error_examples:
- mark_result = parser.mark_error(example[0], example[1], example[2])
- if mark_result:
- raise ParserGenerationError(
- "error marking example: {}\nExample:\n{}".format(
- mark_result, example[3]))
- return parser
+ Raises:
+ ParserGenerationError: There is a problem generating the parser.
+ """
+ parser = lr1.Grammar(start_symbol, productions).parser()
+ if parser.conflicts:
+ raise ParserGenerationError("\n".join([str(c) for c in parser.conflicts]))
+ for example in error_examples:
+ mark_result = parser.mark_error(example[0], example[1], example[2])
+ if mark_result:
+ raise ParserGenerationError(
+ "error marking example: {}\nExample:\n{}".format(
+ mark_result, example[3]
+ )
+ )
+ return parser
@simple_memoizer.memoize
def _load_module_parser():
- error_examples = parse_error_examples(
- resources.load("compiler.front_end", "error_examples"))
- return generate_parser(module_ir.START_SYMBOL, module_ir.PRODUCTIONS,
- error_examples)
+ error_examples = parse_error_examples(
+ resources.load("compiler.front_end", "error_examples")
+ )
+ return generate_parser(
+ module_ir.START_SYMBOL, module_ir.PRODUCTIONS, error_examples
+ )
@simple_memoizer.memoize
def _load_expression_parser():
- return generate_parser(module_ir.EXPRESSION_START_SYMBOL,
- module_ir.PRODUCTIONS, [])
+ return generate_parser(module_ir.EXPRESSION_START_SYMBOL, module_ir.PRODUCTIONS, [])
def parse_module(tokens):
- """Parses the provided Emboss token list into an Emboss module parse tree."""
- return _load_module_parser().parse(tokens)
+ """Parses the provided Emboss token list into an Emboss module parse tree."""
+ return _load_module_parser().parse(tokens)
def parse_expression(tokens):
- """Parses the provided Emboss token list into an expression parse tree."""
- return _load_expression_parser().parse(tokens)
+ """Parses the provided Emboss token list into an expression parse tree."""
+ return _load_expression_parser().parse(tokens)
diff --git a/compiler/front_end/parser_test.py b/compiler/front_end/parser_test.py
index 06cfd03..f5f111b 100644
--- a/compiler/front_end/parser_test.py
+++ b/compiler/front_end/parser_test.py
@@ -23,8 +23,8 @@
# TODO(bolms): This is repeated in lr1_test.py; separate into test utils?
def _parse_productions(*productions):
- """Parses text into a grammar by calling Production.parse on each line."""
- return [parser_types.Production.parse(p) for p in productions]
+ """Parses text into a grammar by calling Production.parse on each line."""
+ return [parser_types.Production.parse(p) for p in productions]
_EXAMPLE_DIVIDER = "\n" + "=" * 80 + "\n"
@@ -33,175 +33,266 @@
class ParserGeneratorTest(unittest.TestCase):
- """Tests parser.parse_error_examples and generate_parser."""
+ """Tests parser.parse_error_examples and generate_parser."""
- def test_parse_good_error_examples(self):
- errors = parser.parse_error_examples(
- _EXAMPLE_DIVIDER + # ======...
- "structure names must be Camel" + # Message.
- _MESSAGE_ERROR_DIVIDER + # ------...
- "struct $ERR FOO" + # First example.
- _ERROR_DIVIDER + # ---
- "struct $ERR foo" + # Second example.
- _EXAMPLE_DIVIDER + # ======...
- ' \n struct must be followed by ":" \n\n' + # Second message.
- _MESSAGE_ERROR_DIVIDER + # ------...
- "struct Foo $ERR") # Example for second message.
- self.assertEqual(tokenizer.tokenize("struct FOO", "")[0], errors[0][0])
- self.assertEqual("structure names must be Camel", errors[0][2])
- self.assertEqual(tokenizer.tokenize("struct foo", "")[0], errors[1][0])
- self.assertEqual("structure names must be Camel", errors[1][2])
- self.assertEqual(tokenizer.tokenize("struct Foo ", "")[0], errors[2][0])
- self.assertEqual('struct must be followed by ":"', errors[2][2])
+ def test_parse_good_error_examples(self):
+ errors = parser.parse_error_examples(
+ _EXAMPLE_DIVIDER # ======...
+ + "structure names must be Camel" # Message.
+ + _MESSAGE_ERROR_DIVIDER # ------...
+ + "struct $ERR FOO" # First example.
+ + _ERROR_DIVIDER # ---
+ + "struct $ERR foo" # Second example.
+ + _EXAMPLE_DIVIDER # ======...
+ + ' \n struct must be followed by ":" \n\n' # Second message.
+ + _MESSAGE_ERROR_DIVIDER # ------...
+ + "struct Foo $ERR"
+ ) # Example for second message.
+ self.assertEqual(tokenizer.tokenize("struct FOO", "")[0], errors[0][0])
+ self.assertEqual("structure names must be Camel", errors[0][2])
+ self.assertEqual(tokenizer.tokenize("struct foo", "")[0], errors[1][0])
+ self.assertEqual("structure names must be Camel", errors[1][2])
+ self.assertEqual(tokenizer.tokenize("struct Foo ", "")[0], errors[2][0])
+ self.assertEqual('struct must be followed by ":"', errors[2][2])
- def test_parse_good_wildcard_example(self):
- errors = parser.parse_error_examples(
- _EXAMPLE_DIVIDER + # ======...
- ' \n struct must be followed by ":" \n\n' + # Second message.
- _MESSAGE_ERROR_DIVIDER + # ------...
- "struct Foo $ERR $ANY")
- tokens = tokenizer.tokenize("struct Foo ", "")[0]
- # The $ANY token should come just before the end-of-line token in the parsed
- # result.
- tokens.insert(-1, lr1.ANY_TOKEN)
- self.assertEqual(tokens, errors[0][0])
- self.assertEqual('struct must be followed by ":"', errors[0][2])
+ def test_parse_good_wildcard_example(self):
+ errors = parser.parse_error_examples(
+ _EXAMPLE_DIVIDER # ======...
+ + ' \n struct must be followed by ":" \n\n' # Second message.
+ + _MESSAGE_ERROR_DIVIDER # ------...
+ + "struct Foo $ERR $ANY"
+ )
+ tokens = tokenizer.tokenize("struct Foo ", "")[0]
+ # The $ANY token should come just before the end-of-line token in the parsed
+ # result.
+ tokens.insert(-1, lr1.ANY_TOKEN)
+ self.assertEqual(tokens, errors[0][0])
+ self.assertEqual('struct must be followed by ":"', errors[0][2])
- def test_parse_with_no_error_marker(self):
- self.assertRaises(
- parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "-- doc")
+ def test_parse_with_no_error_marker(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "-- doc",
+ )
- def test_that_no_error_example_fails(self):
- self.assertRaises(parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + _EXAMPLE_DIVIDER + "msg" +
- _MESSAGE_ERROR_DIVIDER + "example")
+ def test_that_no_error_example_fails(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER
+ + "msg"
+ + _EXAMPLE_DIVIDER
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example",
+ )
- def test_that_message_example_divider_must_be_on_its_own_line(self):
- self.assertRaises(parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "example")
- self.assertRaises(parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + "example")
- self.assertRaises(parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "\nexample")
- self.assertRaises(parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + " \nexample")
+ def test_that_message_example_divider_must_be_on_its_own_line(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "example",
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + "example",
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER + "msg" + "-" * 80 + "\nexample",
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER + "msg\n" + "-" * 80 + " \nexample",
+ )
- def test_that_example_divider_must_be_on_its_own_line(self):
- self.assertRaises(
- parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example" + "=" * 80
- + "msg" + _MESSAGE_ERROR_DIVIDER + "example")
- self.assertRaises(
- parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example\n" + "=" *
- 80 + "msg" + _MESSAGE_ERROR_DIVIDER + "example")
- self.assertRaises(
- parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example" + "=" * 80
- + "\nmsg" + _MESSAGE_ERROR_DIVIDER + "example")
- self.assertRaises(
- parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "msg" + _MESSAGE_ERROR_DIVIDER + "example\n" + "=" *
- 80 + " \nmsg" + _MESSAGE_ERROR_DIVIDER + "example")
+ def test_that_example_divider_must_be_on_its_own_line(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example"
+ + "=" * 80
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example",
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example\n"
+ + "=" * 80
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example",
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example"
+ + "=" * 80
+ + "\nmsg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example",
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER
+ + "msg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example\n"
+ + "=" * 80
+ + " \nmsg"
+ + _MESSAGE_ERROR_DIVIDER
+ + "example",
+ )
- def test_that_tokenization_failure_results_in_failure(self):
- self.assertRaises(
- parser.ParserGenerationError,
- parser.parse_error_examples,
- _EXAMPLE_DIVIDER + "message" + _MESSAGE_ERROR_DIVIDER + "|")
+ def test_that_tokenization_failure_results_in_failure(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.parse_error_examples,
+ _EXAMPLE_DIVIDER + "message" + _MESSAGE_ERROR_DIVIDER + "|",
+ )
- def test_generate_parser(self):
- self.assertTrue(parser.generate_parser("C", _parse_productions("C -> s"),
- []))
- self.assertTrue(parser.generate_parser(
- "C", _parse_productions("C -> s", "C -> d"), []))
+ def test_generate_parser(self):
+ self.assertTrue(parser.generate_parser("C", _parse_productions("C -> s"), []))
+ self.assertTrue(
+ parser.generate_parser("C", _parse_productions("C -> s", "C -> d"), [])
+ )
- def test_generated_parser_error(self):
- test_parser = parser.generate_parser(
- "C", _parse_productions("C -> s", "C -> d"),
- [([parser_types.Token("s", "s", None),
- parser_types.Token("s", "s", None)],
- parser_types.Token("s", "s", None),
- "double s", "ss")])
- parse_result = test_parser.parse([parser_types.Token("s", "s", None),
- parser_types.Token("s", "s", None)])
- self.assertEqual(None, parse_result.parse_tree)
- self.assertEqual("double s", parse_result.error.code)
-
- def test_conflict_error(self):
- self.assertRaises(
- parser.ParserGenerationError,
- parser.generate_parser,
- "C", _parse_productions("C -> S", "C -> D", "S -> a", "D -> a"), [])
-
- def test_bad_mark_error(self):
- self.assertRaises(parser.ParserGenerationError,
- parser.generate_parser,
- "C", _parse_productions("C -> s", "C -> d"),
- [([parser_types.Token("s", "s", None),
- parser_types.Token("s", "s", None)],
+ def test_generated_parser_error(self):
+ test_parser = parser.generate_parser(
+ "C",
+ _parse_productions("C -> s", "C -> d"),
+ [
+ (
+ [
parser_types.Token("s", "s", None),
- "double s", "ss"),
- ([parser_types.Token("s", "s", None),
- parser_types.Token("s", "s", None)],
parser_types.Token("s", "s", None),
- "double 's'", "ss")])
- self.assertRaises(parser.ParserGenerationError,
- parser.generate_parser,
- "C", _parse_productions("C -> s", "C -> d"),
- [([parser_types.Token("s", "s", None)],
+ ],
+ parser_types.Token("s", "s", None),
+ "double s",
+ "ss",
+ )
+ ],
+ )
+ parse_result = test_parser.parse(
+ [parser_types.Token("s", "s", None), parser_types.Token("s", "s", None)]
+ )
+ self.assertEqual(None, parse_result.parse_tree)
+ self.assertEqual("double s", parse_result.error.code)
+
+ def test_conflict_error(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.generate_parser,
+ "C",
+ _parse_productions("C -> S", "C -> D", "S -> a", "D -> a"),
+ [],
+ )
+
+ def test_bad_mark_error(self):
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.generate_parser,
+ "C",
+ _parse_productions("C -> s", "C -> d"),
+ [
+ (
+ [
parser_types.Token("s", "s", None),
- "single s", "s")])
+ parser_types.Token("s", "s", None),
+ ],
+ parser_types.Token("s", "s", None),
+ "double s",
+ "ss",
+ ),
+ (
+ [
+ parser_types.Token("s", "s", None),
+ parser_types.Token("s", "s", None),
+ ],
+ parser_types.Token("s", "s", None),
+ "double 's'",
+ "ss",
+ ),
+ ],
+ )
+ self.assertRaises(
+ parser.ParserGenerationError,
+ parser.generate_parser,
+ "C",
+ _parse_productions("C -> s", "C -> d"),
+ [
+ (
+ [parser_types.Token("s", "s", None)],
+ parser_types.Token("s", "s", None),
+ "single s",
+ "s",
+ )
+ ],
+ )
class ModuleParserTest(unittest.TestCase):
- """Tests for parser.parse_module().
+ """Tests for parser.parse_module().
- Correct parses should mostly be checked in conjunction with
- module_ir.build_ir, as the exact data structure returned by
- parser.parse_module() is determined by the grammar defined in module_ir.
- These tests only need to cover errors and sanity checking.
- """
+ Correct parses should mostly be checked in conjunction with
+ module_ir.build_ir, as the exact data structure returned by
+ parser.parse_module() is determined by the grammar defined in module_ir.
+ These tests only need to cover errors and sanity checking.
+ """
- def test_error_reporting_by_example(self):
- parse_result = parser.parse_module(
- tokenizer.tokenize("struct LogFileStatus:\n"
- " 0 [+4] UInt\n", "")[0])
- self.assertEqual(None, parse_result.parse_tree)
- self.assertEqual("A name is required for a struct field.",
- parse_result.error.code)
- self.assertEqual('"\\n"', parse_result.error.token.symbol)
- self.assertEqual(set(['"["', "SnakeWord", '"."', '":"', '"("']),
- parse_result.error.expected_tokens)
+ def test_error_reporting_by_example(self):
+ parse_result = parser.parse_module(
+ tokenizer.tokenize("struct LogFileStatus:\n" " 0 [+4] UInt\n", "")[0]
+ )
+ self.assertEqual(None, parse_result.parse_tree)
+ self.assertEqual(
+ "A name is required for a struct field.", parse_result.error.code
+ )
+ self.assertEqual('"\\n"', parse_result.error.token.symbol)
+ self.assertEqual(
+ set(['"["', "SnakeWord", '"."', '":"', '"("']),
+ parse_result.error.expected_tokens,
+ )
- def test_error_reporting_without_example(self):
- parse_result = parser.parse_module(
- tokenizer.tokenize("struct LogFileStatus:\n"
- " 0 [+4] UInt foo +\n", "")[0])
- self.assertEqual(None, parse_result.parse_tree)
- self.assertEqual(None, parse_result.error.code)
- self.assertEqual('"+"', parse_result.error.token.symbol)
- self.assertEqual(set(['"("', '"\\n"', '"["', "Documentation", "Comment"]),
- parse_result.error.expected_tokens)
+ def test_error_reporting_without_example(self):
+ parse_result = parser.parse_module(
+ tokenizer.tokenize(
+ "struct LogFileStatus:\n" " 0 [+4] UInt foo +\n", ""
+ )[0]
+ )
+ self.assertEqual(None, parse_result.parse_tree)
+ self.assertEqual(None, parse_result.error.code)
+ self.assertEqual('"+"', parse_result.error.token.symbol)
+ self.assertEqual(
+ set(['"("', '"\\n"', '"["', "Documentation", "Comment"]),
+ parse_result.error.expected_tokens,
+ )
- def test_ok_parse(self):
- parse_result = parser.parse_module(
- tokenizer.tokenize("struct LogFileStatus:\n"
- " 0 [+4] UInt foo\n", "")[0])
- self.assertTrue(parse_result.parse_tree)
- self.assertEqual(None, parse_result.error)
+ def test_ok_parse(self):
+ parse_result = parser.parse_module(
+ tokenizer.tokenize(
+ "struct LogFileStatus:\n" " 0 [+4] UInt foo\n", ""
+ )[0]
+ )
+ self.assertTrue(parse_result.parse_tree)
+ self.assertEqual(None, parse_result.error)
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 9328e8f..498b1a9 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -33,497 +33,602 @@
def ambiguous_name_error(file_name, location, name, candidate_locations):
- """A name cannot be resolved because there are two or more candidates."""
- result = [error.error(file_name, location, "Ambiguous name '{}'".format(name))
- ]
- for location in sorted(candidate_locations):
- result.append(error.note(location.file, location.location,
- "Possible resolution"))
- return result
+ """A name cannot be resolved because there are two or more candidates."""
+ result = [error.error(file_name, location, "Ambiguous name '{}'".format(name))]
+ for location in sorted(candidate_locations):
+ result.append(
+ error.note(location.file, location.location, "Possible resolution")
+ )
+ return result
def duplicate_name_error(file_name, location, name, original_location):
- """A name is defined two or more times."""
- return [error.error(file_name, location, "Duplicate name '{}'".format(name)),
- error.note(original_location.file, original_location.location,
- "Original definition")]
+ """A name is defined two or more times."""
+ return [
+ error.error(file_name, location, "Duplicate name '{}'".format(name)),
+ error.note(
+ original_location.file, original_location.location, "Original definition"
+ ),
+ ]
def missing_name_error(file_name, location, name):
- return [error.error(file_name, location, "No candidate for '{}'".format(name))
- ]
+ return [error.error(file_name, location, "No candidate for '{}'".format(name))]
def array_subfield_error(file_name, location, name):
- return [error.error(file_name, location,
- "Cannot access member of array '{}'".format(name))]
+ return [
+ error.error(
+ file_name, location, "Cannot access member of array '{}'".format(name)
+ )
+ ]
def noncomposite_subfield_error(file_name, location, name):
- return [error.error(file_name, location,
- "Cannot access member of noncomposite field '{}'".format(
- name))]
+ return [
+ error.error(
+ file_name,
+ location,
+ "Cannot access member of noncomposite field '{}'".format(name),
+ )
+ ]
def _nested_name(canonical_name, name):
- """Creates a new CanonicalName with name appended to the object_path."""
- return ir_data.CanonicalName(
- module_file=canonical_name.module_file,
- object_path=list(canonical_name.object_path) + [name])
+ """Creates a new CanonicalName with name appended to the object_path."""
+ return ir_data.CanonicalName(
+ module_file=canonical_name.module_file,
+ object_path=list(canonical_name.object_path) + [name],
+ )
class _Scope(dict):
- """A _Scope holds data for a symbol.
+ """A _Scope holds data for a symbol.
- A _Scope is a dict with some additional attributes. Lexically nested names
- are kept in the dict, and bookkeeping is kept in the additional attributes.
+ A _Scope is a dict with some additional attributes. Lexically nested names
+ are kept in the dict, and bookkeeping is kept in the additional attributes.
- For example, each module should have a child _Scope for each type contained in
- the module. `struct` and `bits` types should have nested _Scopes for each
- field; `enum` types should have nested scopes for each enumerated name.
+ For example, each module should have a child _Scope for each type contained in
+ the module. `struct` and `bits` types should have nested _Scopes for each
+ field; `enum` types should have nested scopes for each enumerated name.
- Attributes:
- canonical_name: The absolute name of this symbol; e.g. ("file.emb",
- "TypeName", "SubTypeName", "field_name")
- source_location: The ir_data.SourceLocation where this symbol is defined.
- visibility: LOCAL, PRIVATE, or SEARCHABLE; see below.
- alias: If set, this name is merely a pointer to another name.
- """
- __slots__ = ("canonical_name", "source_location", "visibility", "alias")
+ Attributes:
+ canonical_name: The absolute name of this symbol; e.g. ("file.emb",
+ "TypeName", "SubTypeName", "field_name")
+ source_location: The ir_data.SourceLocation where this symbol is defined.
+ visibility: LOCAL, PRIVATE, or SEARCHABLE; see below.
+ alias: If set, this name is merely a pointer to another name.
+ """
- # A LOCAL name is visible outside of its enclosing scope, but should not be
- # found when searching for a name. That is, this name should be matched in
- # the tail of a qualified reference (the 'bar' in 'foo.bar'), but not when
- # searching for names (the 'foo' in 'foo.bar' should not match outside of
- # 'foo's scope). This applies to public field names.
- LOCAL = object()
+ __slots__ = ("canonical_name", "source_location", "visibility", "alias")
- # A PRIVATE name is similar to LOCAL except that it is never visible outside
- # its enclosing scope. This applies to abbreviations of field names: if 'a'
- # is an abbreviation for field 'apple', then 'foo.a' is not a valid reference;
- # instead it should be 'foo.apple'.
- PRIVATE = object()
+ # A LOCAL name is visible outside of its enclosing scope, but should not be
+ # found when searching for a name. That is, this name should be matched in
+ # the tail of a qualified reference (the 'bar' in 'foo.bar'), but not when
+ # searching for names (the 'foo' in 'foo.bar' should not match outside of
+ # 'foo's scope). This applies to public field names.
+ LOCAL = object()
- # A SEARCHABLE name is visible as long as it is in a scope in the search list.
- # This applies to type names ('Foo'), which may be found from many scopes.
- SEARCHABLE = object()
+ # A PRIVATE name is similar to LOCAL except that it is never visible outside
+ # its enclosing scope. This applies to abbreviations of field names: if 'a'
+ # is an abbreviation for field 'apple', then 'foo.a' is not a valid reference;
+ # instead it should be 'foo.apple'.
+ PRIVATE = object()
- def __init__(self, canonical_name, source_location, visibility, alias=None):
- super(_Scope, self).__init__()
- self.canonical_name = canonical_name
- self.source_location = source_location
- self.visibility = visibility
- self.alias = alias
+ # A SEARCHABLE name is visible as long as it is in a scope in the search list.
+ # This applies to type names ('Foo'), which may be found from many scopes.
+ SEARCHABLE = object()
+
+ def __init__(self, canonical_name, source_location, visibility, alias=None):
+ super(_Scope, self).__init__()
+ self.canonical_name = canonical_name
+ self.source_location = source_location
+ self.visibility = visibility
+ self.alias = alias
def _add_name_to_scope(name_ir, scope, canonical_name, visibility, errors):
- """Adds the given name_ir to the given scope."""
- name = name_ir.text
- new_scope = _Scope(canonical_name, name_ir.source_location, visibility)
- if name in scope:
- errors.append(duplicate_name_error(
- scope.canonical_name.module_file, name_ir.source_location, name,
- FileLocation(scope[name].canonical_name.module_file,
- scope[name].source_location)))
- else:
- scope[name] = new_scope
- return new_scope
+ """Adds the given name_ir to the given scope."""
+ name = name_ir.text
+ new_scope = _Scope(canonical_name, name_ir.source_location, visibility)
+ if name in scope:
+ errors.append(
+ duplicate_name_error(
+ scope.canonical_name.module_file,
+ name_ir.source_location,
+ name,
+ FileLocation(
+ scope[name].canonical_name.module_file, scope[name].source_location
+ ),
+ )
+ )
+ else:
+ scope[name] = new_scope
+ return new_scope
def _add_name_to_scope_and_normalize(name_ir, scope, visibility, errors):
- """Adds the given name_ir to scope and sets its canonical_name."""
- name = name_ir.name.text
- canonical_name = _nested_name(scope.canonical_name, name)
- ir_data_utils.builder(name_ir).canonical_name.CopyFrom(canonical_name)
- return _add_name_to_scope(name_ir.name, scope, canonical_name, visibility,
- errors)
+ """Adds the given name_ir to scope and sets its canonical_name."""
+ name = name_ir.name.text
+ canonical_name = _nested_name(scope.canonical_name, name)
+ ir_data_utils.builder(name_ir).canonical_name.CopyFrom(canonical_name)
+ return _add_name_to_scope(name_ir.name, scope, canonical_name, visibility, errors)
def _add_struct_field_to_scope(field, scope, errors):
- """Adds the name of the given field to the scope."""
- new_scope = _add_name_to_scope_and_normalize(field.name, scope, _Scope.LOCAL,
- errors)
- if field.HasField("abbreviation"):
- _add_name_to_scope(field.abbreviation, scope, new_scope.canonical_name,
- _Scope.PRIVATE, errors)
+ """Adds the name of the given field to the scope."""
+ new_scope = _add_name_to_scope_and_normalize(
+ field.name, scope, _Scope.LOCAL, errors
+ )
+ if field.HasField("abbreviation"):
+ _add_name_to_scope(
+ field.abbreviation, scope, new_scope.canonical_name, _Scope.PRIVATE, errors
+ )
- value_builtin_name = ir_data.Word(
- text="this",
- source_location=ir_data.Location(is_synthetic=True),
- )
- # In "inside field" scope, the name `this` maps back to the field itself.
- # This is important for attributes like `[requires]`.
- _add_name_to_scope(value_builtin_name, new_scope,
- field.name.canonical_name, _Scope.PRIVATE, errors)
+ value_builtin_name = ir_data.Word(
+ text="this",
+ source_location=ir_data.Location(is_synthetic=True),
+ )
+ # In "inside field" scope, the name `this` maps back to the field itself.
+ # This is important for attributes like `[requires]`.
+ _add_name_to_scope(
+ value_builtin_name, new_scope, field.name.canonical_name, _Scope.PRIVATE, errors
+ )
def _add_parameter_name_to_scope(parameter, scope, errors):
- """Adds the name of the given parameter to the scope."""
- _add_name_to_scope_and_normalize(parameter.name, scope, _Scope.LOCAL, errors)
+ """Adds the name of the given parameter to the scope."""
+ _add_name_to_scope_and_normalize(parameter.name, scope, _Scope.LOCAL, errors)
def _add_enum_value_to_scope(value, scope, errors):
- """Adds the name of the enum value to scope."""
- _add_name_to_scope_and_normalize(value.name, scope, _Scope.LOCAL, errors)
+ """Adds the name of the enum value to scope."""
+ _add_name_to_scope_and_normalize(value.name, scope, _Scope.LOCAL, errors)
def _add_type_name_to_scope(type_definition, scope, errors):
- """Adds the name of type_definition to the given scope."""
- new_scope = _add_name_to_scope_and_normalize(type_definition.name, scope,
- _Scope.SEARCHABLE, errors)
- return {"scope": new_scope}
+ """Adds the name of type_definition to the given scope."""
+ new_scope = _add_name_to_scope_and_normalize(
+ type_definition.name, scope, _Scope.SEARCHABLE, errors
+ )
+ return {"scope": new_scope}
def _set_scope_for_type_definition(type_definition, scope):
- """Sets the current scope for an ir_data.AddressableUnit."""
- return {"scope": scope[type_definition.name.name.text]}
+ """Sets the current scope for an ir_data.AddressableUnit."""
+ return {"scope": scope[type_definition.name.name.text]}
def _add_module_to_scope(module, scope):
- """Adds the name of the module to the given scope."""
- module_symbol_table = _Scope(
- ir_data.CanonicalName(module_file=module.source_file_name,
- object_path=[]),
- None,
- _Scope.SEARCHABLE)
- scope[module.source_file_name] = module_symbol_table
- return {"scope": scope[module.source_file_name]}
+ """Adds the name of the module to the given scope."""
+ module_symbol_table = _Scope(
+ ir_data.CanonicalName(module_file=module.source_file_name, object_path=[]),
+ None,
+ _Scope.SEARCHABLE,
+ )
+ scope[module.source_file_name] = module_symbol_table
+ return {"scope": scope[module.source_file_name]}
def _set_scope_for_module(module, scope):
- """Adds the name of the module to the given scope."""
- return {"scope": scope[module.source_file_name]}
+ """Adds the name of the module to the given scope."""
+ return {"scope": scope[module.source_file_name]}
def _add_import_to_scope(foreign_import, table, module, errors):
- if not foreign_import.local_name.text:
- # This is the prelude import; ignore it.
- return
- _add_alias_to_scope(foreign_import.local_name, table, module.canonical_name,
- [foreign_import.file_name.text], _Scope.SEARCHABLE,
- errors)
+ if not foreign_import.local_name.text:
+ # This is the prelude import; ignore it.
+ return
+ _add_alias_to_scope(
+ foreign_import.local_name,
+ table,
+ module.canonical_name,
+ [foreign_import.file_name.text],
+ _Scope.SEARCHABLE,
+ errors,
+ )
def _construct_symbol_tables(ir):
- """Constructs per-module symbol tables for each module in ir."""
- symbol_tables = {}
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Module], _add_module_to_scope,
- parameters={"errors": errors, "scope": symbol_tables})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.TypeDefinition], _add_type_name_to_scope,
- incidental_actions={ir_data.Module: _set_scope_for_module},
- parameters={"errors": errors, "scope": symbol_tables})
- if errors:
- # Ideally, we would find duplicate field names elsewhere in the module, even
- # if there are duplicate type names, but field/enum names in the colliding
- # types also end up colliding, leading to spurious errors. E.g., if you
- # have two `struct Foo`s, then the field check will also discover a
- # collision for `$size_in_bytes`, since there are two `Foo.$size_in_bytes`.
- return symbol_tables, errors
+ """Constructs per-module symbol tables for each module in ir."""
+ symbol_tables = {}
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Module],
+ _add_module_to_scope,
+ parameters={"errors": errors, "scope": symbol_tables},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.TypeDefinition],
+ _add_type_name_to_scope,
+ incidental_actions={ir_data.Module: _set_scope_for_module},
+ parameters={"errors": errors, "scope": symbol_tables},
+ )
+ if errors:
+ # Ideally, we would find duplicate field names elsewhere in the module, even
+ # if there are duplicate type names, but field/enum names in the colliding
+ # types also end up colliding, leading to spurious errors. E.g., if you
+ # have two `struct Foo`s, then the field check will also discover a
+ # collision for `$size_in_bytes`, since there are two `Foo.$size_in_bytes`.
+ return symbol_tables, errors
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.EnumValue], _add_enum_value_to_scope,
- incidental_actions={
- ir_data.Module: _set_scope_for_module,
- ir_data.TypeDefinition: _set_scope_for_type_definition,
- },
- parameters={"errors": errors, "scope": symbol_tables})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Field], _add_struct_field_to_scope,
- incidental_actions={
- ir_data.Module: _set_scope_for_module,
- ir_data.TypeDefinition: _set_scope_for_type_definition,
- },
- parameters={"errors": errors, "scope": symbol_tables})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.RuntimeParameter], _add_parameter_name_to_scope,
- incidental_actions={
- ir_data.Module: _set_scope_for_module,
- ir_data.TypeDefinition: _set_scope_for_type_definition,
- },
- parameters={"errors": errors, "scope": symbol_tables})
- return symbol_tables, errors
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.EnumValue],
+ _add_enum_value_to_scope,
+ incidental_actions={
+ ir_data.Module: _set_scope_for_module,
+ ir_data.TypeDefinition: _set_scope_for_type_definition,
+ },
+ parameters={"errors": errors, "scope": symbol_tables},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Field],
+ _add_struct_field_to_scope,
+ incidental_actions={
+ ir_data.Module: _set_scope_for_module,
+ ir_data.TypeDefinition: _set_scope_for_type_definition,
+ },
+ parameters={"errors": errors, "scope": symbol_tables},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.RuntimeParameter],
+ _add_parameter_name_to_scope,
+ incidental_actions={
+ ir_data.Module: _set_scope_for_module,
+ ir_data.TypeDefinition: _set_scope_for_type_definition,
+ },
+ parameters={"errors": errors, "scope": symbol_tables},
+ )
+ return symbol_tables, errors
def _add_alias_to_scope(name_ir, table, scope, alias, visibility, errors):
- """Adds the given name to the scope as an alias."""
- name = name_ir.text
- new_scope = _Scope(_nested_name(scope, name), name_ir.source_location,
- visibility, alias)
- scoped_table = table[scope.module_file]
- for path_element in scope.object_path:
- scoped_table = scoped_table[path_element]
- if name in scoped_table:
- errors.append(duplicate_name_error(
- scoped_table.canonical_name.module_file, name_ir.source_location, name,
- FileLocation(scoped_table[name].canonical_name.module_file,
- scoped_table[name].source_location)))
- else:
- scoped_table[name] = new_scope
- return new_scope
-
-
-def _resolve_head_of_field_reference(field_reference, table, current_scope,
- visible_scopes, source_file_name, errors):
- return _resolve_reference(
- field_reference.path[0], table, current_scope,
- visible_scopes, source_file_name, errors)
-
-
-def _resolve_reference(reference, table, current_scope, visible_scopes,
- source_file_name, errors):
- """Sets the canonical name of the given reference."""
- if reference.HasField("canonical_name"):
- # This reference has already been resolved by the _resolve_field_reference
- # pass.
- return
- target = _find_target_of_reference(reference, table, current_scope,
- visible_scopes, source_file_name, errors)
- if target is not None:
- assert not target.alias
- ir_data_utils.builder(reference).canonical_name.CopyFrom(target.canonical_name)
-
-
-def _find_target_of_reference(reference, table, current_scope, visible_scopes,
- source_file_name, errors):
- """Returns the resolved name of the given reference."""
- found_in_table = None
- name = reference.source_name[0].text
- for scope in visible_scopes:
+ """Adds the given name to the scope as an alias."""
+ name = name_ir.text
+ new_scope = _Scope(
+ _nested_name(scope, name), name_ir.source_location, visibility, alias
+ )
scoped_table = table[scope.module_file]
- for path_element in scope.object_path or []:
- scoped_table = scoped_table[path_element]
- if (name in scoped_table and
- (scope == current_scope or
- scoped_table[name].visibility == _Scope.SEARCHABLE)):
- # Prelude is "", so explicitly check for None.
- if found_in_table is not None:
- # TODO(bolms): Currently, this catches the case where a module tries to
- # use a name that is defined (at the same scope) in two different
- # modules. It may make sense to raise duplicate_name_error whenever two
- # modules define the same name (whether it is used or not), and reserve
- # ambiguous_name_error for cases where a name is found in multiple
- # scopes.
- errors.append(ambiguous_name_error(
- source_file_name, reference.source_location, name, [FileLocation(
- found_in_table[name].canonical_name.module_file,
- found_in_table[name].source_location), FileLocation(
- scoped_table[name].canonical_name.module_file, scoped_table[
- name].source_location)]))
- continue
- found_in_table = scoped_table
- if reference.is_local_name:
- # This is a little hacky. When "is_local_name" is True, the name refers
- # to a type that was defined inline. In many cases, the type should be
- # found at the same scope as the field; e.g.:
- #
- # struct Foo:
- # 0 [+1] enum bar:
- # BAZ = 1
- #
- # In this case, `Foo.bar` has type `Foo.Bar`. Unfortunately, things
- # break down a little bit when there is an inline type in an anonymous
- # `bits`:
- #
- # struct Foo:
- # 0 [+1] bits:
- # 0 [+7] enum bar:
- # BAZ = 1
- #
- # Types inside of anonymous `bits` are hoisted into their parent type,
- # so instead of `Foo.EmbossReservedAnonymous1.Bar`, `bar`'s type is just
- # `Foo.Bar`. Unfortunately, the field is still
- # `Foo.EmbossReservedAnonymous1.bar`, so `bar`'s type won't be found in
- # `bar`'s `current_scope`.
- #
- # (The name `bar` is exposed from `Foo` as an alias virtual field, so
- # perhaps the correct answer is to allow type aliases, so that `Bar` can
- # be found in both `Foo` and `Foo.EmbossReservedAnonymous1`. That would
- # involve an entirely new feature, though.)
- #
- # The workaround here is to search scopes from the innermost outward,
- # and just stop as soon as a match is found. This isn't ideal, because
- # it relies on other bits of the front end having correctly added the
- # inline type to the correct scope before symbol resolution, but it does
- # work. Names with False `is_local_name` will still be checked for
- # ambiguity.
- break
- if found_in_table is None:
- errors.append(missing_name_error(
- source_file_name, reference.source_name[0].source_location, name))
- if not errors:
- for subname in reference.source_name:
- if subname.text not in found_in_table:
- errors.append(missing_name_error(source_file_name,
- subname.source_location, subname.text))
- return None
- found_in_table = found_in_table[subname.text]
- while found_in_table.alias:
- referenced_table = table
- for name in found_in_table.alias:
- referenced_table = referenced_table[name]
- # TODO(bolms): This section should really be a recursive lookup
- # function, which would be able to handle arbitrary aliases through
- # other aliases.
- #
- # This should be fine for now, since the only aliases here should be
- # imports, which can't refer to other imports.
- assert not referenced_table.alias, "Alias found to contain alias."
- found_in_table = referenced_table
- return found_in_table
- return None
+ for path_element in scope.object_path:
+ scoped_table = scoped_table[path_element]
+ if name in scoped_table:
+ errors.append(
+ duplicate_name_error(
+ scoped_table.canonical_name.module_file,
+ name_ir.source_location,
+ name,
+ FileLocation(
+ scoped_table[name].canonical_name.module_file,
+ scoped_table[name].source_location,
+ ),
+ )
+ )
+ else:
+ scoped_table[name] = new_scope
+ return new_scope
+
+
+def _resolve_head_of_field_reference(
+ field_reference, table, current_scope, visible_scopes, source_file_name, errors
+):
+ return _resolve_reference(
+ field_reference.path[0],
+ table,
+ current_scope,
+ visible_scopes,
+ source_file_name,
+ errors,
+ )
+
+
+def _resolve_reference(
+ reference, table, current_scope, visible_scopes, source_file_name, errors
+):
+ """Sets the canonical name of the given reference."""
+ if reference.HasField("canonical_name"):
+ # This reference has already been resolved by the _resolve_field_reference
+ # pass.
+ return
+ target = _find_target_of_reference(
+ reference, table, current_scope, visible_scopes, source_file_name, errors
+ )
+ if target is not None:
+ assert not target.alias
+ ir_data_utils.builder(reference).canonical_name.CopyFrom(target.canonical_name)
+
+
+def _find_target_of_reference(
+ reference, table, current_scope, visible_scopes, source_file_name, errors
+):
+ """Returns the resolved name of the given reference."""
+ found_in_table = None
+ name = reference.source_name[0].text
+ for scope in visible_scopes:
+ scoped_table = table[scope.module_file]
+ for path_element in scope.object_path or []:
+ scoped_table = scoped_table[path_element]
+ if name in scoped_table and (
+ scope == current_scope or scoped_table[name].visibility == _Scope.SEARCHABLE
+ ):
+ # Prelude is "", so explicitly check for None.
+ if found_in_table is not None:
+ # TODO(bolms): Currently, this catches the case where a module tries to
+ # use a name that is defined (at the same scope) in two different
+ # modules. It may make sense to raise duplicate_name_error whenever two
+ # modules define the same name (whether it is used or not), and reserve
+ # ambiguous_name_error for cases where a name is found in multiple
+ # scopes.
+ errors.append(
+ ambiguous_name_error(
+ source_file_name,
+ reference.source_location,
+ name,
+ [
+ FileLocation(
+ found_in_table[name].canonical_name.module_file,
+ found_in_table[name].source_location,
+ ),
+ FileLocation(
+ scoped_table[name].canonical_name.module_file,
+ scoped_table[name].source_location,
+ ),
+ ],
+ )
+ )
+ continue
+ found_in_table = scoped_table
+ if reference.is_local_name:
+ # This is a little hacky. When "is_local_name" is True, the name refers
+ # to a type that was defined inline. In many cases, the type should be
+ # found at the same scope as the field; e.g.:
+ #
+ # struct Foo:
+ # 0 [+1] enum bar:
+ # BAZ = 1
+ #
+ # In this case, `Foo.bar` has type `Foo.Bar`. Unfortunately, things
+ # break down a little bit when there is an inline type in an anonymous
+ # `bits`:
+ #
+ # struct Foo:
+ # 0 [+1] bits:
+ # 0 [+7] enum bar:
+ # BAZ = 1
+ #
+ # Types inside of anonymous `bits` are hoisted into their parent type,
+ # so instead of `Foo.EmbossReservedAnonymous1.Bar`, `bar`'s type is just
+ # `Foo.Bar`. Unfortunately, the field is still
+ # `Foo.EmbossReservedAnonymous1.bar`, so `bar`'s type won't be found in
+ # `bar`'s `current_scope`.
+ #
+ # (The name `bar` is exposed from `Foo` as an alias virtual field, so
+ # perhaps the correct answer is to allow type aliases, so that `Bar` can
+ # be found in both `Foo` and `Foo.EmbossReservedAnonymous1`. That would
+ # involve an entirely new feature, though.)
+ #
+ # The workaround here is to search scopes from the innermost outward,
+ # and just stop as soon as a match is found. This isn't ideal, because
+ # it relies on other bits of the front end having correctly added the
+ # inline type to the correct scope before symbol resolution, but it does
+ # work. Names with False `is_local_name` will still be checked for
+ # ambiguity.
+ break
+ if found_in_table is None:
+ errors.append(
+ missing_name_error(
+ source_file_name, reference.source_name[0].source_location, name
+ )
+ )
+ if not errors:
+ for subname in reference.source_name:
+ if subname.text not in found_in_table:
+ errors.append(
+ missing_name_error(
+ source_file_name, subname.source_location, subname.text
+ )
+ )
+ return None
+ found_in_table = found_in_table[subname.text]
+ while found_in_table.alias:
+ referenced_table = table
+ for name in found_in_table.alias:
+ referenced_table = referenced_table[name]
+ # TODO(bolms): This section should really be a recursive lookup
+ # function, which would be able to handle arbitrary aliases through
+ # other aliases.
+ #
+ # This should be fine for now, since the only aliases here should be
+ # imports, which can't refer to other imports.
+ assert not referenced_table.alias, "Alias found to contain alias."
+ found_in_table = referenced_table
+ return found_in_table
+ return None
def _resolve_field_reference(field_reference, source_file_name, errors, ir):
- """Resolves the References inside of a FieldReference."""
- if field_reference.path[-1].HasField("canonical_name"):
- # Already done.
- return
- previous_field = ir_util.find_object_or_none(field_reference.path[0], ir)
- previous_reference = field_reference.path[0]
- for ref in field_reference.path[1:]:
- while ir_util.field_is_virtual(previous_field):
- if (previous_field.read_transform.WhichOneof("expression") ==
- "field_reference"):
- # Pass a separate error list into the recursive _resolve_field_reference
- # call so that only one copy of the error for a particular reference
- # will actually surface: in particular, the one that results from a
- # direct call from traverse_ir_top_down into _resolve_field_reference.
- new_errors = []
- _resolve_field_reference(
- previous_field.read_transform.field_reference,
- previous_field.name.canonical_name.module_file, new_errors, ir)
- # If the recursive _resolve_field_reference was unable to resolve the
- # field, then bail. Otherwise we get a cascade of errors, where an
- # error in `x` leads to errors in anything trying to reach a member of
- # `x`.
- if not previous_field.read_transform.field_reference.path[-1].HasField(
- "canonical_name"):
- return
- previous_field = ir_util.find_object(
- previous_field.read_transform.field_reference.path[-1], ir)
- else:
- errors.append(
- noncomposite_subfield_error(source_file_name,
- previous_reference.source_location,
- previous_reference.source_name[0].text))
+ """Resolves the References inside of a FieldReference."""
+ if field_reference.path[-1].HasField("canonical_name"):
+ # Already done.
return
- if previous_field.type.WhichOneof("type") == "array_type":
- errors.append(
- array_subfield_error(source_file_name,
- previous_reference.source_location,
- previous_reference.source_name[0].text))
- return
- assert previous_field.type.WhichOneof("type") == "atomic_type"
- member_name = ir_data_utils.copy(
- previous_field.type.atomic_type.reference.canonical_name)
- ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text])
- previous_field = ir_util.find_object_or_none(member_name, ir)
- if previous_field is None:
- errors.append(
- missing_name_error(source_file_name,
- ref.source_name[0].source_location,
- ref.source_name[0].text))
- return
- ir_data_utils.builder(ref).canonical_name.CopyFrom(member_name)
- previous_reference = ref
+ previous_field = ir_util.find_object_or_none(field_reference.path[0], ir)
+ previous_reference = field_reference.path[0]
+ for ref in field_reference.path[1:]:
+ while ir_util.field_is_virtual(previous_field):
+ if (
+ previous_field.read_transform.WhichOneof("expression")
+ == "field_reference"
+ ):
+ # Pass a separate error list into the recursive _resolve_field_reference
+ # call so that only one copy of the error for a particular reference
+ # will actually surface: in particular, the one that results from a
+ # direct call from traverse_ir_top_down into _resolve_field_reference.
+ new_errors = []
+ _resolve_field_reference(
+ previous_field.read_transform.field_reference,
+ previous_field.name.canonical_name.module_file,
+ new_errors,
+ ir,
+ )
+ # If the recursive _resolve_field_reference was unable to resolve the
+ # field, then bail. Otherwise we get a cascade of errors, where an
+ # error in `x` leads to errors in anything trying to reach a member of
+ # `x`.
+ if not previous_field.read_transform.field_reference.path[-1].HasField(
+ "canonical_name"
+ ):
+ return
+ previous_field = ir_util.find_object(
+ previous_field.read_transform.field_reference.path[-1], ir
+ )
+ else:
+ errors.append(
+ noncomposite_subfield_error(
+ source_file_name,
+ previous_reference.source_location,
+ previous_reference.source_name[0].text,
+ )
+ )
+ return
+ if previous_field.type.WhichOneof("type") == "array_type":
+ errors.append(
+ array_subfield_error(
+ source_file_name,
+ previous_reference.source_location,
+ previous_reference.source_name[0].text,
+ )
+ )
+ return
+ assert previous_field.type.WhichOneof("type") == "atomic_type"
+ member_name = ir_data_utils.copy(
+ previous_field.type.atomic_type.reference.canonical_name
+ )
+ ir_data_utils.builder(member_name).object_path.extend([ref.source_name[0].text])
+ previous_field = ir_util.find_object_or_none(member_name, ir)
+ if previous_field is None:
+ errors.append(
+ missing_name_error(
+ source_file_name,
+ ref.source_name[0].source_location,
+ ref.source_name[0].text,
+ )
+ )
+ return
+ ir_data_utils.builder(ref).canonical_name.CopyFrom(member_name)
+ previous_reference = ref
def _set_visible_scopes_for_type_definition(type_definition, visible_scopes):
- """Sets current_scope and visible_scopes for the given type_definition."""
- return {
- "current_scope": type_definition.name.canonical_name,
-
- # In order to ensure that the iteration through scopes in
- # _find_target_of_reference will go from innermost to outermost, it is
- # important that the current scope (type_definition.name.canonical_name)
- # precedes the previous visible_scopes here.
- "visible_scopes": (type_definition.name.canonical_name,) + visible_scopes,
- }
+ """Sets current_scope and visible_scopes for the given type_definition."""
+ return {
+ "current_scope": type_definition.name.canonical_name,
+ # In order to ensure that the iteration through scopes in
+ # _find_target_of_reference will go from innermost to outermost, it is
+ # important that the current scope (type_definition.name.canonical_name)
+ # precedes the previous visible_scopes here.
+ "visible_scopes": (type_definition.name.canonical_name,) + visible_scopes,
+ }
def _set_visible_scopes_for_module(module):
- """Sets visible_scopes for the given module."""
- self_scope = ir_data.CanonicalName(module_file=module.source_file_name)
- extra_visible_scopes = []
- for foreign_import in module.foreign_import:
- # Anonymous imports are searched for top-level names; named imports are not.
- # As of right now, only the prelude should be imported anonymously; other
- # modules must be imported with names.
- if not foreign_import.local_name.text:
- extra_visible_scopes.append(
- ir_data.CanonicalName(module_file=foreign_import.file_name.text))
- return {"visible_scopes": (self_scope,) + tuple(extra_visible_scopes)}
+ """Sets visible_scopes for the given module."""
+ self_scope = ir_data.CanonicalName(module_file=module.source_file_name)
+ extra_visible_scopes = []
+ for foreign_import in module.foreign_import:
+ # Anonymous imports are searched for top-level names; named imports are not.
+ # As of right now, only the prelude should be imported anonymously; other
+ # modules must be imported with names.
+ if not foreign_import.local_name.text:
+ extra_visible_scopes.append(
+ ir_data.CanonicalName(module_file=foreign_import.file_name.text)
+ )
+ return {"visible_scopes": (self_scope,) + tuple(extra_visible_scopes)}
def _set_visible_scopes_for_attribute(attribute, field, visible_scopes):
- """Sets current_scope and visible_scopes for the attribute."""
- del attribute # Unused
- if field is None:
- return
- return {
- "current_scope": field.name.canonical_name,
- "visible_scopes": (field.name.canonical_name,) + visible_scopes,
- }
+ """Sets current_scope and visible_scopes for the attribute."""
+ del attribute # Unused
+ if field is None:
+ return
+ return {
+ "current_scope": field.name.canonical_name,
+ "visible_scopes": (field.name.canonical_name,) + visible_scopes,
+ }
+
def _module_source_from_table_action(m, table):
- return {"module": table[m.source_file_name]}
+ return {"module": table[m.source_file_name]}
+
def _resolve_symbols_from_table(ir, table):
- """Resolves all references in the given IR, given the constructed table."""
- errors = []
- # Symbol resolution is broken into five passes. First, this code resolves any
- # imports, and adds import aliases to modules.
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Import], _add_import_to_scope,
- incidental_actions={
- ir_data.Module: _module_source_from_table_action,
- },
- parameters={"errors": errors, "table": table})
- if errors:
+ """Resolves all references in the given IR, given the constructed table."""
+ errors = []
+ # Symbol resolution is broken into five passes. First, this code resolves any
+ # imports, and adds import aliases to modules.
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Import],
+ _add_import_to_scope,
+ incidental_actions={
+ ir_data.Module: _module_source_from_table_action,
+ },
+ parameters={"errors": errors, "table": table},
+ )
+ if errors:
+ return errors
+ # Next, this resolves all absolute references (e.g., it resolves "UInt" in
+ # "0:1 UInt field" to [prelude]::UInt).
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Reference],
+ _resolve_reference,
+ skip_descendants_of=(ir_data.FieldReference,),
+ incidental_actions={
+ ir_data.TypeDefinition: _set_visible_scopes_for_type_definition,
+ ir_data.Module: _set_visible_scopes_for_module,
+ ir_data.Attribute: _set_visible_scopes_for_attribute,
+ },
+ parameters={"table": table, "errors": errors, "field": None},
+ )
+ # Lastly, head References to fields (e.g., the `a` of `a.b.c`) are resolved.
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.FieldReference],
+ _resolve_head_of_field_reference,
+ incidental_actions={
+ ir_data.TypeDefinition: _set_visible_scopes_for_type_definition,
+ ir_data.Module: _set_visible_scopes_for_module,
+ ir_data.Attribute: _set_visible_scopes_for_attribute,
+ },
+ parameters={"table": table, "errors": errors, "field": None},
+ )
return errors
- # Next, this resolves all absolute references (e.g., it resolves "UInt" in
- # "0:1 UInt field" to [prelude]::UInt).
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Reference], _resolve_reference,
- skip_descendants_of=(ir_data.FieldReference,),
- incidental_actions={
- ir_data.TypeDefinition: _set_visible_scopes_for_type_definition,
- ir_data.Module: _set_visible_scopes_for_module,
- ir_data.Attribute: _set_visible_scopes_for_attribute,
- },
- parameters={"table": table, "errors": errors, "field": None})
- # Lastly, head References to fields (e.g., the `a` of `a.b.c`) are resolved.
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.FieldReference], _resolve_head_of_field_reference,
- incidental_actions={
- ir_data.TypeDefinition: _set_visible_scopes_for_type_definition,
- ir_data.Module: _set_visible_scopes_for_module,
- ir_data.Attribute: _set_visible_scopes_for_attribute,
- },
- parameters={"table": table, "errors": errors, "field": None})
- return errors
def resolve_field_references(ir):
- """Resolves structure member accesses ("field.subfield") in ir."""
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.FieldReference], _resolve_field_reference,
- incidental_actions={
- ir_data.TypeDefinition: _set_visible_scopes_for_type_definition,
- ir_data.Module: _set_visible_scopes_for_module,
- ir_data.Attribute: _set_visible_scopes_for_attribute,
- },
- parameters={"errors": errors, "field": None})
- return errors
+ """Resolves structure member accesses ("field.subfield") in ir."""
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.FieldReference],
+ _resolve_field_reference,
+ incidental_actions={
+ ir_data.TypeDefinition: _set_visible_scopes_for_type_definition,
+ ir_data.Module: _set_visible_scopes_for_module,
+ ir_data.Attribute: _set_visible_scopes_for_attribute,
+ },
+ parameters={"errors": errors, "field": None},
+ )
+ return errors
def resolve_symbols(ir):
- """Resolves the symbols in all modules in ir."""
- symbol_tables, errors = _construct_symbol_tables(ir)
- if errors:
- return errors
- return _resolve_symbols_from_table(ir, symbol_tables)
+ """Resolves the symbols in all modules in ir."""
+ symbol_tables, errors = _construct_symbol_tables(ir)
+ if errors:
+ return errors
+ return _resolve_symbols_from_table(ir, symbol_tables)
diff --git a/compiler/front_end/symbol_resolver_test.py b/compiler/front_end/symbol_resolver_test.py
index deaf1a0..693d157 100644
--- a/compiler/front_end/symbol_resolver_test.py
+++ b/compiler/front_end/symbol_resolver_test.py
@@ -51,698 +51,969 @@
class ResolveSymbolsTest(unittest.TestCase):
- """Tests for symbol_resolver.resolve_symbols()."""
+ """Tests for symbol_resolver.resolve_symbols()."""
- def _construct_ir_multiple(self, file_dict, primary_emb_name):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- primary_emb_name,
- test_util.dict_file_reader(file_dict),
- stop_before_step="resolve_symbols")
- assert not errors
- return ir
+ def _construct_ir_multiple(self, file_dict, primary_emb_name):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ primary_emb_name,
+ test_util.dict_file_reader(file_dict),
+ stop_before_step="resolve_symbols",
+ )
+ assert not errors
+ return ir
- def _construct_ir(self, emb_text, name="happy.emb"):
- return self._construct_ir_multiple({name: emb_text}, name)
+ def _construct_ir(self, emb_text, name="happy.emb"):
+ return self._construct_ir_multiple({name: emb_text}, name)
- def test_struct_field_atomic_type_resolution(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- struct_ir = ir.module[0].type[0].structure
- atomic_field1_reference = struct_ir.field[0].type.atomic_type.reference
- self.assertEqual(atomic_field1_reference.canonical_name.object_path, ["UInt"
- ])
- self.assertEqual(atomic_field1_reference.canonical_name.module_file, "")
- atomic_field2_reference = struct_ir.field[1].type.atomic_type.reference
- self.assertEqual(atomic_field2_reference.canonical_name.object_path, ["Bar"
- ])
- self.assertEqual(atomic_field2_reference.canonical_name.module_file,
- "happy.emb")
+ def test_struct_field_atomic_type_resolution(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ struct_ir = ir.module[0].type[0].structure
+ atomic_field1_reference = struct_ir.field[0].type.atomic_type.reference
+ self.assertEqual(atomic_field1_reference.canonical_name.object_path, ["UInt"])
+ self.assertEqual(atomic_field1_reference.canonical_name.module_file, "")
+ atomic_field2_reference = struct_ir.field[1].type.atomic_type.reference
+ self.assertEqual(atomic_field2_reference.canonical_name.object_path, ["Bar"])
+ self.assertEqual(
+ atomic_field2_reference.canonical_name.module_file, "happy.emb"
+ )
- def test_struct_field_enum_type_resolution(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- struct_ir = ir.module[0].type[1].structure
- atomic_field_reference = struct_ir.field[0].type.atomic_type.reference
- self.assertEqual(atomic_field_reference.canonical_name.object_path, ["Qux"])
- self.assertEqual(atomic_field_reference.canonical_name.module_file,
- "happy.emb")
+ def test_struct_field_enum_type_resolution(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ struct_ir = ir.module[0].type[1].structure
+ atomic_field_reference = struct_ir.field[0].type.atomic_type.reference
+ self.assertEqual(atomic_field_reference.canonical_name.object_path, ["Qux"])
+ self.assertEqual(atomic_field_reference.canonical_name.module_file, "happy.emb")
- def test_struct_field_array_type_resolution(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- array_field_type = ir.module[0].type[0].structure.field[2].type.array_type
- array_field_reference = array_field_type.base_type.atomic_type.reference
- self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"])
- self.assertEqual(array_field_reference.canonical_name.module_file, "")
+ def test_struct_field_array_type_resolution(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ array_field_type = ir.module[0].type[0].structure.field[2].type.array_type
+ array_field_reference = array_field_type.base_type.atomic_type.reference
+ self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"])
+ self.assertEqual(array_field_reference.canonical_name.module_file, "")
- def test_inner_type_resolution(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- array_field_type = ir.module[0].type[0].structure.field[2].type.array_type
- array_field_reference = array_field_type.base_type.atomic_type.reference
- self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"])
- self.assertEqual(array_field_reference.canonical_name.module_file, "")
+ def test_inner_type_resolution(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ array_field_type = ir.module[0].type[0].structure.field[2].type.array_type
+ array_field_reference = array_field_type.base_type.atomic_type.reference
+ self.assertEqual(array_field_reference.canonical_name.object_path, ["UInt"])
+ self.assertEqual(array_field_reference.canonical_name.module_file, "")
- def test_struct_field_resolution_in_expression_in_location(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- struct_ir = ir.module[0].type[3].structure
- field0_loc = struct_ir.field[0].location
- abbreviation_reference = field0_loc.size.field_reference.path[0]
- self.assertEqual(abbreviation_reference.canonical_name.object_path,
- ["FieldRef", "offset"])
- self.assertEqual(abbreviation_reference.canonical_name.module_file,
- "happy.emb")
- field0_start_left = field0_loc.start.function.args[0]
- nested_abbreviation_reference = field0_start_left.field_reference.path[0]
- self.assertEqual(nested_abbreviation_reference.canonical_name.object_path,
- ["FieldRef", "offset"])
- self.assertEqual(nested_abbreviation_reference.canonical_name.module_file,
- "happy.emb")
- field1_loc = struct_ir.field[1].location
- direct_reference = field1_loc.size.field_reference.path[0]
- self.assertEqual(direct_reference.canonical_name.object_path, ["FieldRef",
- "offset"])
- self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb")
- field1_start_left = field1_loc.start.function.args[0]
- nested_direct_reference = field1_start_left.field_reference.path[0]
- self.assertEqual(nested_direct_reference.canonical_name.object_path,
- ["FieldRef", "offset"])
- self.assertEqual(nested_direct_reference.canonical_name.module_file,
- "happy.emb")
+ def test_struct_field_resolution_in_expression_in_location(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ struct_ir = ir.module[0].type[3].structure
+ field0_loc = struct_ir.field[0].location
+ abbreviation_reference = field0_loc.size.field_reference.path[0]
+ self.assertEqual(
+ abbreviation_reference.canonical_name.object_path, ["FieldRef", "offset"]
+ )
+ self.assertEqual(abbreviation_reference.canonical_name.module_file, "happy.emb")
+ field0_start_left = field0_loc.start.function.args[0]
+ nested_abbreviation_reference = field0_start_left.field_reference.path[0]
+ self.assertEqual(
+ nested_abbreviation_reference.canonical_name.object_path,
+ ["FieldRef", "offset"],
+ )
+ self.assertEqual(
+ nested_abbreviation_reference.canonical_name.module_file, "happy.emb"
+ )
+ field1_loc = struct_ir.field[1].location
+ direct_reference = field1_loc.size.field_reference.path[0]
+ self.assertEqual(
+ direct_reference.canonical_name.object_path, ["FieldRef", "offset"]
+ )
+ self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb")
+ field1_start_left = field1_loc.start.function.args[0]
+ nested_direct_reference = field1_start_left.field_reference.path[0]
+ self.assertEqual(
+ nested_direct_reference.canonical_name.object_path, ["FieldRef", "offset"]
+ )
+ self.assertEqual(
+ nested_direct_reference.canonical_name.module_file, "happy.emb"
+ )
- def test_struct_field_resolution_in_expression_in_array_length(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- struct_ir = ir.module[0].type[3].structure
- field0_array_type = struct_ir.field[0].type.array_type
- field0_array_element_count = field0_array_type.element_count
- abbreviation_reference = field0_array_element_count.field_reference.path[0]
- self.assertEqual(abbreviation_reference.canonical_name.object_path,
- ["FieldRef", "offset"])
- self.assertEqual(abbreviation_reference.canonical_name.module_file,
- "happy.emb")
- field1_array_type = struct_ir.field[1].type.array_type
- direct_reference = field1_array_type.element_count.field_reference.path[0]
- self.assertEqual(direct_reference.canonical_name.object_path, ["FieldRef",
- "offset"])
- self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb")
+ def test_struct_field_resolution_in_expression_in_array_length(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ struct_ir = ir.module[0].type[3].structure
+ field0_array_type = struct_ir.field[0].type.array_type
+ field0_array_element_count = field0_array_type.element_count
+ abbreviation_reference = field0_array_element_count.field_reference.path[0]
+ self.assertEqual(
+ abbreviation_reference.canonical_name.object_path, ["FieldRef", "offset"]
+ )
+ self.assertEqual(abbreviation_reference.canonical_name.module_file, "happy.emb")
+ field1_array_type = struct_ir.field[1].type.array_type
+ direct_reference = field1_array_type.element_count.field_reference.path[0]
+ self.assertEqual(
+ direct_reference.canonical_name.object_path, ["FieldRef", "offset"]
+ )
+ self.assertEqual(direct_reference.canonical_name.module_file, "happy.emb")
- def test_struct_parameter_resolution(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- struct_ir = ir.module[0].type[6].structure
- size_ir = struct_ir.field[0].location.size
- self.assertTrue(size_ir.HasField("field_reference"))
- self.assertEqual(size_ir.field_reference.path[0].canonical_name.object_path,
- ["UsesParameter", "x"])
+ def test_struct_parameter_resolution(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ struct_ir = ir.module[0].type[6].structure
+ size_ir = struct_ir.field[0].location.size
+ self.assertTrue(size_ir.HasField("field_reference"))
+ self.assertEqual(
+ size_ir.field_reference.path[0].canonical_name.object_path,
+ ["UsesParameter", "x"],
+ )
- def test_enum_value_resolution_in_expression_in_enum_field(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- enum_ir = ir.module[0].type[5].enumeration
- value_reference = enum_ir.value[1].value.constant_reference
- self.assertEqual(value_reference.canonical_name.object_path,
- ["Quux", "ABC"])
- self.assertEqual(value_reference.canonical_name.module_file, "happy.emb")
+ def test_enum_value_resolution_in_expression_in_enum_field(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ enum_ir = ir.module[0].type[5].enumeration
+ value_reference = enum_ir.value[1].value.constant_reference
+ self.assertEqual(value_reference.canonical_name.object_path, ["Quux", "ABC"])
+ self.assertEqual(value_reference.canonical_name.module_file, "happy.emb")
- def test_symbol_resolution_in_expression_in_void_array_length(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- struct_ir = ir.module[0].type[4].structure
- array_type = struct_ir.field[0].type.array_type
- # The symbol resolver should ignore void fields.
- self.assertEqual("automatic", array_type.WhichOneof("size"))
+ def test_symbol_resolution_in_expression_in_void_array_length(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ struct_ir = ir.module[0].type[4].structure
+ array_type = struct_ir.field[0].type.array_type
+ # The symbol resolver should ignore void fields.
+ self.assertEqual("automatic", array_type.WhichOneof("size"))
- def test_name_definitions_have_correct_canonical_names(self):
- ir = self._construct_ir(_HAPPY_EMB)
- self.assertEqual([], symbol_resolver.resolve_symbols(ir))
- foo_name = ir.module[0].type[0].name
- self.assertEqual(foo_name.canonical_name.object_path, ["Foo"])
- self.assertEqual(foo_name.canonical_name.module_file, "happy.emb")
- uint_field_name = ir.module[0].type[0].structure.field[0].name
- self.assertEqual(uint_field_name.canonical_name.object_path, ["Foo",
- "uint_field"])
- self.assertEqual(uint_field_name.canonical_name.module_file, "happy.emb")
- foo_name = ir.module[0].type[2].name
- self.assertEqual(foo_name.canonical_name.object_path, ["Qux"])
- self.assertEqual(foo_name.canonical_name.module_file, "happy.emb")
+ def test_name_definitions_have_correct_canonical_names(self):
+ ir = self._construct_ir(_HAPPY_EMB)
+ self.assertEqual([], symbol_resolver.resolve_symbols(ir))
+ foo_name = ir.module[0].type[0].name
+ self.assertEqual(foo_name.canonical_name.object_path, ["Foo"])
+ self.assertEqual(foo_name.canonical_name.module_file, "happy.emb")
+ uint_field_name = ir.module[0].type[0].structure.field[0].name
+ self.assertEqual(
+ uint_field_name.canonical_name.object_path, ["Foo", "uint_field"]
+ )
+ self.assertEqual(uint_field_name.canonical_name.module_file, "happy.emb")
+ foo_name = ir.module[0].type[2].name
+ self.assertEqual(foo_name.canonical_name.object_path, ["Qux"])
+ self.assertEqual(foo_name.canonical_name.module_file, "happy.emb")
- def test_duplicate_type_name(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+4] UInt field\n"
- "struct Foo:\n"
- " 0 [+4] UInt bar\n", "duplicate_type.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- self.assertEqual([
- [error.error("duplicate_type.emb",
- ir.module[0].type[1].name.source_location,
- "Duplicate name 'Foo'"),
- error.note("duplicate_type.emb",
- ir.module[0].type[0].name.source_location,
- "Original definition")]
- ], errors)
+ def test_duplicate_type_name(self):
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+4] UInt field\n"
+ "struct Foo:\n"
+ " 0 [+4] UInt bar\n",
+ "duplicate_type.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_type.emb",
+ ir.module[0].type[1].name.source_location,
+ "Duplicate name 'Foo'",
+ ),
+ error.note(
+ "duplicate_type.emb",
+ ir.module[0].type[0].name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_duplicate_field_name_in_struct(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+4] UInt field\n"
- " 4 [+4] UInt field\n", "duplicate_field.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("duplicate_field.emb",
- struct.field[1].name.source_location,
- "Duplicate name 'field'"),
- error.note("duplicate_field.emb",
- struct.field[0].name.source_location,
- "Original definition")
- ]], errors)
+ def test_duplicate_field_name_in_struct(self):
+ ir = self._construct_ir(
+ "struct Foo:\n" " 0 [+4] UInt field\n" " 4 [+4] UInt field\n",
+ "duplicate_field.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_field.emb",
+ struct.field[1].name.source_location,
+ "Duplicate name 'field'",
+ ),
+ error.note(
+ "duplicate_field.emb",
+ struct.field[0].name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_duplicate_abbreviation_in_struct(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+4] UInt field1 (f)\n"
- " 4 [+4] UInt field2 (f)\n",
- "duplicate_field.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("duplicate_field.emb",
- struct.field[1].abbreviation.source_location,
- "Duplicate name 'f'"),
- error.note("duplicate_field.emb",
- struct.field[0].abbreviation.source_location,
- "Original definition")
- ]], errors)
+ def test_duplicate_abbreviation_in_struct(self):
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+4] UInt field1 (f)\n"
+ " 4 [+4] UInt field2 (f)\n",
+ "duplicate_field.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_field.emb",
+ struct.field[1].abbreviation.source_location,
+ "Duplicate name 'f'",
+ ),
+ error.note(
+ "duplicate_field.emb",
+ struct.field[0].abbreviation.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_abbreviation_duplicates_field_name_in_struct(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+4] UInt field\n"
- " 4 [+4] UInt field2 (field)\n",
- "duplicate_field.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("duplicate_field.emb",
- struct.field[1].abbreviation.source_location,
- "Duplicate name 'field'"),
- error.note("duplicate_field.emb",
- struct.field[0].name.source_location,
- "Original definition")
- ]], errors)
+ def test_abbreviation_duplicates_field_name_in_struct(self):
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+4] UInt field\n"
+ " 4 [+4] UInt field2 (field)\n",
+ "duplicate_field.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_field.emb",
+ struct.field[1].abbreviation.source_location,
+ "Duplicate name 'field'",
+ ),
+ error.note(
+ "duplicate_field.emb",
+ struct.field[0].name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_field_name_duplicates_abbreviation_in_struct(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+4] UInt field (field2)\n"
- " 4 [+4] UInt field2\n", "duplicate_field.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("duplicate_field.emb",
- struct.field[1].name.source_location,
- "Duplicate name 'field2'"),
- error.note("duplicate_field.emb",
- struct.field[0].abbreviation.source_location,
- "Original definition")
- ]], errors)
+ def test_field_name_duplicates_abbreviation_in_struct(self):
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+4] UInt field (field2)\n"
+ " 4 [+4] UInt field2\n",
+ "duplicate_field.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_field.emb",
+ struct.field[1].name.source_location,
+ "Duplicate name 'field2'",
+ ),
+ error.note(
+ "duplicate_field.emb",
+ struct.field[0].abbreviation.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_duplicate_value_name_in_enum(self):
- ir = self._construct_ir("enum Foo:\n"
- " BAR = 1\n"
- " BAR = 1\n", "duplicate_enum.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- self.assertEqual([[
- error.error(
- "duplicate_enum.emb",
- ir.module[0].type[0].enumeration.value[1].name.source_location,
- "Duplicate name 'BAR'"),
- error.note(
- "duplicate_enum.emb",
- ir.module[0].type[0].enumeration.value[0].name.source_location,
- "Original definition")
- ]], errors)
+ def test_duplicate_value_name_in_enum(self):
+ ir = self._construct_ir(
+ "enum Foo:\n" " BAR = 1\n" " BAR = 1\n", "duplicate_enum.emb"
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_enum.emb",
+ ir.module[0].type[0].enumeration.value[1].name.source_location,
+ "Duplicate name 'BAR'",
+ ),
+ error.note(
+ "duplicate_enum.emb",
+ ir.module[0].type[0].enumeration.value[0].name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_ambiguous_name(self):
- # struct UInt will be ambiguous with the external UInt in the prelude.
- ir = self._construct_ir("struct UInt:\n"
- " 0 [+4] Int:8[4] field\n"
- "struct Foo:\n"
- " 0 [+4] UInt bar\n", "ambiguous.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- # Find the UInt definition in the prelude.
- for type_ir in ir.module[1].type:
- if type_ir.name.name.text == "UInt":
- prelude_uint = type_ir
- break
- ambiguous_type_ir = ir.module[0].type[1].structure.field[0].type.atomic_type
- self.assertEqual([[
- error.error("ambiguous.emb",
- ambiguous_type_ir.reference.source_name[0].source_location,
- "Ambiguous name 'UInt'"), error.note(
- "", prelude_uint.name.source_location,
- "Possible resolution"),
- error.note("ambiguous.emb", ir.module[0].type[0].name.source_location,
- "Possible resolution")
- ]], errors)
+ def test_ambiguous_name(self):
+ # struct UInt will be ambiguous with the external UInt in the prelude.
+ ir = self._construct_ir(
+ "struct UInt:\n"
+ " 0 [+4] Int:8[4] field\n"
+ "struct Foo:\n"
+ " 0 [+4] UInt bar\n",
+ "ambiguous.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ # Find the UInt definition in the prelude.
+ for type_ir in ir.module[1].type:
+ if type_ir.name.name.text == "UInt":
+ prelude_uint = type_ir
+ break
+ ambiguous_type_ir = ir.module[0].type[1].structure.field[0].type.atomic_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "ambiguous.emb",
+ ambiguous_type_ir.reference.source_name[0].source_location,
+ "Ambiguous name 'UInt'",
+ ),
+ error.note(
+ "", prelude_uint.name.source_location, "Possible resolution"
+ ),
+ error.note(
+ "ambiguous.emb",
+ ir.module[0].type[0].name.source_location,
+ "Possible resolution",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_missing_name(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+4] Bar field\n",
- "missing.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- missing_type_ir = ir.module[0].type[0].structure.field[0].type.atomic_type
- self.assertEqual([
- [error.error("missing.emb",
- missing_type_ir.reference.source_name[0].source_location,
- "No candidate for 'Bar'")]
- ], errors)
+ def test_missing_name(self):
+ ir = self._construct_ir("struct Foo:\n" " 0 [+4] Bar field\n", "missing.emb")
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ missing_type_ir = ir.module[0].type[0].structure.field[0].type.atomic_type
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "missing.emb",
+ missing_type_ir.reference.source_name[0].source_location,
+ "No candidate for 'Bar'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_missing_leading_name(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+Num.FOUR] UInt field\n", "missing.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
- self.assertEqual([
- [error.error(
+ def test_missing_leading_name(self):
+ ir = self._construct_ir(
+ "struct Foo:\n" " 0 [+Num.FOUR] UInt field\n", "missing.emb"
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "missing.emb",
+ missing_expr_ir.constant_reference.source_name[
+ 0
+ ].source_location,
+ "No candidate for 'Num'",
+ )
+ ]
+ ],
+ errors,
+ )
+
+ def test_missing_trailing_name(self):
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+Num.FOUR] UInt field\n"
+ "enum Num:\n"
+ " THREE = 3\n",
"missing.emb",
- missing_expr_ir.constant_reference.source_name[0].source_location,
- "No candidate for 'Num'")]
- ], errors)
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "missing.emb",
+ missing_expr_ir.constant_reference.source_name[
+ 1
+ ].source_location,
+ "No candidate for 'FOUR'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_missing_trailing_name(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+Num.FOUR] UInt field\n"
- "enum Num:\n"
- " THREE = 3\n", "missing.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
- self.assertEqual([
- [error.error(
+ def test_missing_middle_name(self):
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+Num.NaN.FOUR] UInt field\n"
+ "enum Num:\n"
+ " FOUR = 4\n",
"missing.emb",
- missing_expr_ir.constant_reference.source_name[1].source_location,
- "No candidate for 'FOUR'")]
- ], errors)
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "missing.emb",
+ missing_expr_ir.constant_reference.source_name[
+ 1
+ ].source_location,
+ "No candidate for 'NaN'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_missing_middle_name(self):
- ir = self._construct_ir("struct Foo:\n"
- " 0 [+Num.NaN.FOUR] UInt field\n"
- "enum Num:\n"
- " FOUR = 4\n", "missing.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- missing_expr_ir = ir.module[0].type[0].structure.field[0].location.size
- self.assertEqual([
- [error.error(
- "missing.emb",
- missing_expr_ir.constant_reference.source_name[1].source_location,
- "No candidate for 'NaN'")]
- ], errors)
+ def test_inner_resolution(self):
+ ir = self._construct_ir(
+ "struct OuterStruct:\n"
+ "\n"
+ " struct InnerStruct2:\n"
+ " 0 [+1] InnerStruct.InnerEnum inner_enum\n"
+ "\n"
+ " struct InnerStruct:\n"
+ " enum InnerEnum:\n"
+ " ONE = 1\n"
+ "\n"
+ " 0 [+1] InnerEnum inner_enum\n"
+ "\n"
+ " 0 [+InnerStruct.InnerEnum.ONE] InnerStruct.InnerEnum inner_enum\n",
+ "nested.emb",
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ self.assertFalse(errors)
+ outer_struct = ir.module[0].type[0]
+ inner_struct = outer_struct.subtype[1]
+ inner_struct_2 = outer_struct.subtype[0]
+ inner_enum = inner_struct.subtype[0]
+ self.assertEqual(
+ ["OuterStruct", "InnerStruct"],
+ list(inner_struct.name.canonical_name.object_path),
+ )
+ self.assertEqual(
+ ["OuterStruct", "InnerStruct", "InnerEnum"],
+ list(inner_enum.name.canonical_name.object_path),
+ )
+ self.assertEqual(
+ ["OuterStruct", "InnerStruct2"],
+ list(inner_struct_2.name.canonical_name.object_path),
+ )
+ outer_field = outer_struct.structure.field[0]
+ outer_field_end_ref = outer_field.location.size.constant_reference
+ self.assertEqual(
+ ["OuterStruct", "InnerStruct", "InnerEnum", "ONE"],
+ list(outer_field_end_ref.canonical_name.object_path),
+ )
+ self.assertEqual(
+ ["OuterStruct", "InnerStruct", "InnerEnum"],
+ list(outer_field.type.atomic_type.reference.canonical_name.object_path),
+ )
+ inner_field_2_type = inner_struct_2.structure.field[0].type.atomic_type
+ self.assertEqual(
+ ["OuterStruct", "InnerStruct", "InnerEnum"],
+ list(inner_field_2_type.reference.canonical_name.object_path),
+ )
- def test_inner_resolution(self):
- ir = self._construct_ir(
- "struct OuterStruct:\n"
- "\n"
- " struct InnerStruct2:\n"
- " 0 [+1] InnerStruct.InnerEnum inner_enum\n"
- "\n"
- " struct InnerStruct:\n"
- " enum InnerEnum:\n"
- " ONE = 1\n"
- "\n"
- " 0 [+1] InnerEnum inner_enum\n"
- "\n"
- " 0 [+InnerStruct.InnerEnum.ONE] InnerStruct.InnerEnum inner_enum\n",
- "nested.emb")
- errors = symbol_resolver.resolve_symbols(ir)
- self.assertFalse(errors)
- outer_struct = ir.module[0].type[0]
- inner_struct = outer_struct.subtype[1]
- inner_struct_2 = outer_struct.subtype[0]
- inner_enum = inner_struct.subtype[0]
- self.assertEqual(["OuterStruct", "InnerStruct"],
- list(inner_struct.name.canonical_name.object_path))
- self.assertEqual(["OuterStruct", "InnerStruct", "InnerEnum"],
- list(inner_enum.name.canonical_name.object_path))
- self.assertEqual(["OuterStruct", "InnerStruct2"],
- list(inner_struct_2.name.canonical_name.object_path))
- outer_field = outer_struct.structure.field[0]
- outer_field_end_ref = outer_field.location.size.constant_reference
- self.assertEqual(
- ["OuterStruct", "InnerStruct", "InnerEnum", "ONE"], list(
- outer_field_end_ref.canonical_name.object_path))
- self.assertEqual(
- ["OuterStruct", "InnerStruct", "InnerEnum"],
- list(outer_field.type.atomic_type.reference.canonical_name.object_path))
- inner_field_2_type = inner_struct_2.structure.field[0].type.atomic_type
- self.assertEqual(
- ["OuterStruct", "InnerStruct", "InnerEnum"
- ], list(inner_field_2_type.reference.canonical_name.object_path))
+ def test_resolution_against_anonymous_bits(self):
+ ir = self._construct_ir(
+ "struct Struct:\n"
+ " 0 [+1] bits:\n"
+ " 7 [+1] Flag last_packet\n"
+ " 5 [+2] enum inline_inner_enum:\n"
+ " AA = 0\n"
+ " BB = 1\n"
+ " CC = 2\n"
+ " DD = 3\n"
+ " 0 [+5] UInt header_size (h)\n"
+ " 0 [+h] UInt:8[] header_bytes\n"
+ "\n"
+ "struct Struct2:\n"
+ " 0 [+1] Struct.InlineInnerEnum value\n",
+ "anonymity.emb",
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ self.assertFalse(errors)
+ struct1 = ir.module[0].type[0]
+ struct1_bits_field = struct1.structure.field[0]
+ struct1_bits_field_type = struct1_bits_field.type.atomic_type.reference
+ struct1_byte_field = struct1.structure.field[4]
+ inner_bits = struct1.subtype[0]
+ inner_enum = struct1.subtype[1]
+ self.assertTrue(inner_bits.HasField("structure"))
+ self.assertTrue(inner_enum.HasField("enumeration"))
+ self.assertTrue(inner_bits.name.is_anonymous)
+ self.assertFalse(inner_enum.name.is_anonymous)
+ self.assertEqual(
+ ["Struct", "InlineInnerEnum"],
+ list(inner_enum.name.canonical_name.object_path),
+ )
+ self.assertEqual(
+ ["Struct", "InlineInnerEnum", "AA"],
+ list(inner_enum.enumeration.value[0].name.canonical_name.object_path),
+ )
+ self.assertEqual(
+ list(inner_bits.name.canonical_name.object_path),
+ list(struct1_bits_field_type.canonical_name.object_path),
+ )
+ self.assertEqual(2, len(inner_bits.name.canonical_name.object_path))
+ self.assertEqual(
+ ["Struct", "header_size"],
+ list(
+ struct1_byte_field.location.size.field_reference.path[
+ 0
+ ].canonical_name.object_path
+ ),
+ )
- def test_resolution_against_anonymous_bits(self):
- ir = self._construct_ir("struct Struct:\n"
- " 0 [+1] bits:\n"
- " 7 [+1] Flag last_packet\n"
- " 5 [+2] enum inline_inner_enum:\n"
- " AA = 0\n"
- " BB = 1\n"
- " CC = 2\n"
- " DD = 3\n"
- " 0 [+5] UInt header_size (h)\n"
- " 0 [+h] UInt:8[] header_bytes\n"
- "\n"
- "struct Struct2:\n"
- " 0 [+1] Struct.InlineInnerEnum value\n",
- "anonymity.emb")
- errors = symbol_resolver.resolve_symbols(ir)
- self.assertFalse(errors)
- struct1 = ir.module[0].type[0]
- struct1_bits_field = struct1.structure.field[0]
- struct1_bits_field_type = struct1_bits_field.type.atomic_type.reference
- struct1_byte_field = struct1.structure.field[4]
- inner_bits = struct1.subtype[0]
- inner_enum = struct1.subtype[1]
- self.assertTrue(inner_bits.HasField("structure"))
- self.assertTrue(inner_enum.HasField("enumeration"))
- self.assertTrue(inner_bits.name.is_anonymous)
- self.assertFalse(inner_enum.name.is_anonymous)
- self.assertEqual(["Struct", "InlineInnerEnum"],
- list(inner_enum.name.canonical_name.object_path))
- self.assertEqual(
- ["Struct", "InlineInnerEnum", "AA"],
- list(inner_enum.enumeration.value[0].name.canonical_name.object_path))
- self.assertEqual(
- list(inner_bits.name.canonical_name.object_path),
- list(struct1_bits_field_type.canonical_name.object_path))
- self.assertEqual(2, len(inner_bits.name.canonical_name.object_path))
- self.assertEqual(
- ["Struct", "header_size"],
- list(struct1_byte_field.location.size.field_reference.path[0].
- canonical_name.object_path))
-
- def test_duplicate_name_in_different_inline_bits(self):
- ir = self._construct_ir(
- "struct Struct:\n"
- " 0 [+1] bits:\n"
- " 7 [+1] Flag a\n"
- " 1 [+1] bits:\n"
- " 0 [+1] Flag a\n", "duplicate_in_anon.emb")
- errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
- supertype = ir.module[0].type[0]
- self.assertEqual([[
- error.error(
+ def test_duplicate_name_in_different_inline_bits(self):
+ ir = self._construct_ir(
+ "struct Struct:\n"
+ " 0 [+1] bits:\n"
+ " 7 [+1] Flag a\n"
+ " 1 [+1] bits:\n"
+ " 0 [+1] Flag a\n",
"duplicate_in_anon.emb",
- supertype.structure.field[3].name.source_location,
- "Duplicate name 'a'"),
- error.note(
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_symbols(ir))
+ supertype = ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_in_anon.emb",
+ supertype.structure.field[3].name.source_location,
+ "Duplicate name 'a'",
+ ),
+ error.note(
+ "duplicate_in_anon.emb",
+ supertype.structure.field[1].name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
+
+ def test_duplicate_name_in_same_inline_bits(self):
+ ir = self._construct_ir(
+ "struct Struct:\n"
+ " 0 [+1] bits:\n"
+ " 7 [+1] Flag a\n"
+ " 0 [+1] Flag a\n",
"duplicate_in_anon.emb",
- supertype.structure.field[1].name.source_location,
- "Original definition")
- ]], errors)
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ supertype = ir.module[0].type[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "duplicate_in_anon.emb",
+ supertype.structure.field[2].name.source_location,
+ "Duplicate name 'a'",
+ ),
+ error.note(
+ "duplicate_in_anon.emb",
+ supertype.structure.field[1].name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ error.filter_errors(errors),
+ )
- def test_duplicate_name_in_same_inline_bits(self):
- ir = self._construct_ir(
- "struct Struct:\n"
- " 0 [+1] bits:\n"
- " 7 [+1] Flag a\n"
- " 0 [+1] Flag a\n", "duplicate_in_anon.emb")
- errors = symbol_resolver.resolve_symbols(ir)
- supertype = ir.module[0].type[0]
- self.assertEqual([[
- error.error(
- "duplicate_in_anon.emb",
- supertype.structure.field[2].name.source_location,
- "Duplicate name 'a'"),
- error.note(
- "duplicate_in_anon.emb",
- supertype.structure.field[1].name.source_location,
- "Original definition")
- ]], error.filter_errors(errors))
+ def test_import_type_resolution(self):
+ importer = 'import "ed.emb" as ed\n' "struct Ff:\n" " 0 [+1] ed.Gg gg\n"
+ imported = "struct Gg:\n" " 0 [+1] UInt qq\n"
+ ir = self._construct_ir_multiple(
+ {"ed.emb": imported, "er.emb": importer}, "er.emb"
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ self.assertEqual([], errors)
- def test_import_type_resolution(self):
- importer = ('import "ed.emb" as ed\n'
- "struct Ff:\n"
- " 0 [+1] ed.Gg gg\n")
- imported = ("struct Gg:\n"
- " 0 [+1] UInt qq\n")
- ir = self._construct_ir_multiple({"ed.emb": imported, "er.emb": importer},
- "er.emb")
- errors = symbol_resolver.resolve_symbols(ir)
- self.assertEqual([], errors)
+ def test_duplicate_import_name(self):
+ importer = (
+ 'import "ed.emb" as ed\n'
+ 'import "ed.emb" as ed\n'
+ "struct Ff:\n"
+ " 0 [+1] ed.Gg gg\n"
+ )
+ imported = "struct Gg:\n" " 0 [+1] UInt qq\n"
+ ir = self._construct_ir_multiple(
+ {"ed.emb": imported, "er.emb": importer}, "er.emb"
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ # Note: the error is on import[2] duplicating import[1] because the implicit
+ # prelude import is import[0].
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "er.emb",
+ ir.module[0].foreign_import[2].local_name.source_location,
+ "Duplicate name 'ed'",
+ ),
+ error.note(
+ "er.emb",
+ ir.module[0].foreign_import[1].local_name.source_location,
+ "Original definition",
+ ),
+ ]
+ ],
+ errors,
+ )
- def test_duplicate_import_name(self):
- importer = ('import "ed.emb" as ed\n'
- 'import "ed.emb" as ed\n'
- "struct Ff:\n"
- " 0 [+1] ed.Gg gg\n")
- imported = ("struct Gg:\n"
- " 0 [+1] UInt qq\n")
- ir = self._construct_ir_multiple({"ed.emb": imported, "er.emb": importer},
- "er.emb")
- errors = symbol_resolver.resolve_symbols(ir)
- # Note: the error is on import[2] duplicating import[1] because the implicit
- # prelude import is import[0].
- self.assertEqual([
- [error.error("er.emb",
- ir.module[0].foreign_import[2].local_name.source_location,
- "Duplicate name 'ed'"),
- error.note("er.emb",
- ir.module[0].foreign_import[1].local_name.source_location,
- "Original definition")]
- ], errors)
+ def test_import_enum_resolution(self):
+ importer = (
+ 'import "ed.emb" as ed\n'
+ "struct Ff:\n"
+ " if ed.Gg.GG == ed.Gg.GG:\n"
+ " 0 [+1] UInt gg\n"
+ )
+ imported = "enum Gg:\n" " GG = 0\n"
+ ir = self._construct_ir_multiple(
+ {"ed.emb": imported, "er.emb": importer}, "er.emb"
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ self.assertEqual([], errors)
- def test_import_enum_resolution(self):
- importer = ('import "ed.emb" as ed\n'
- "struct Ff:\n"
- " if ed.Gg.GG == ed.Gg.GG:\n"
- " 0 [+1] UInt gg\n")
- imported = ("enum Gg:\n"
- " GG = 0\n")
- ir = self._construct_ir_multiple({"ed.emb": imported, "er.emb": importer},
- "er.emb")
- errors = symbol_resolver.resolve_symbols(ir)
- self.assertEqual([], errors)
+ def test_that_double_import_names_are_syntactically_invalid(self):
+ # There are currently no checks in resolve_symbols that it is not possible
+ # to get to symbols imported by another module, because it is syntactically
+ # invalid. This may change in the future, in which case this test should be
+ # fixed by adding an explicit check to resolve_symbols and checking the
+ # error message here.
+ importer = 'import "ed.emb" as ed\n' "struct Ff:\n" " 0 [+1] ed.ed2.Gg gg\n"
+ imported = 'import "ed2.emb" as ed2\n'
+ imported2 = "struct Gg:\n" " 0 [+1] UInt qq\n"
+ unused_ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "er.emb",
+ test_util.dict_file_reader(
+ {"ed.emb": imported, "ed2.emb": imported2, "er.emb": importer}
+ ),
+ stop_before_step="resolve_symbols",
+ )
+ assert errors
- def test_that_double_import_names_are_syntactically_invalid(self):
- # There are currently no checks in resolve_symbols that it is not possible
- # to get to symbols imported by another module, because it is syntactically
- # invalid. This may change in the future, in which case this test should be
- # fixed by adding an explicit check to resolve_symbols and checking the
- # error message here.
- importer = ('import "ed.emb" as ed\n'
- "struct Ff:\n"
- " 0 [+1] ed.ed2.Gg gg\n")
- imported = 'import "ed2.emb" as ed2\n'
- imported2 = ("struct Gg:\n"
- " 0 [+1] UInt qq\n")
- unused_ir, unused_debug_info, errors = glue.parse_emboss_file(
- "er.emb",
- test_util.dict_file_reader({"ed.emb": imported,
- "ed2.emb": imported2,
- "er.emb": importer}),
- stop_before_step="resolve_symbols")
- assert errors
+ def test_no_error_when_inline_name_aliases_outer_name(self):
+ # The inline enum's complete type should be Foo.Foo. During parsing, the
+ # name is set to just "Foo", but symbol resolution should a) select the
+ # correct Foo, and b) not complain that multiple Foos could match.
+ ir = self._construct_ir(
+ "struct Foo:\n" " 0 [+1] enum foo:\n" " BAR = 0\n"
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ self.assertEqual([], errors)
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ ["Foo", "Foo"],
+ list(field.type.atomic_type.reference.canonical_name.object_path),
+ )
- def test_no_error_when_inline_name_aliases_outer_name(self):
- # The inline enum's complete type should be Foo.Foo. During parsing, the
- # name is set to just "Foo", but symbol resolution should a) select the
- # correct Foo, and b) not complain that multiple Foos could match.
- ir = self._construct_ir(
- "struct Foo:\n"
- " 0 [+1] enum foo:\n"
- " BAR = 0\n")
- errors = symbol_resolver.resolve_symbols(ir)
- self.assertEqual([], errors)
- field = ir.module[0].type[0].structure.field[0]
- self.assertEqual(
- ["Foo", "Foo"],
- list(field.type.atomic_type.reference.canonical_name.object_path))
-
- def test_no_error_when_inline_name_in_anonymous_bits_aliases_outer_name(self):
- # There is an extra layer of complexity when an inline type appears inside
- # of an inline bits.
- ir = self._construct_ir(
- "struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] enum foo:\n"
- " BAR = 0\n")
- errors = symbol_resolver.resolve_symbols(ir)
- self.assertEqual([], error.filter_errors(errors))
- field = ir.module[0].type[0].subtype[0].structure.field[0]
- self.assertEqual(
- ["Foo", "Foo"],
- list(field.type.atomic_type.reference.canonical_name.object_path))
+ def test_no_error_when_inline_name_in_anonymous_bits_aliases_outer_name(self):
+ # There is an extra layer of complexity when an inline type appears inside
+ # of an inline bits.
+ ir = self._construct_ir(
+ "struct Foo:\n"
+ " 0 [+1] bits:\n"
+ " 0 [+4] enum foo:\n"
+ " BAR = 0\n"
+ )
+ errors = symbol_resolver.resolve_symbols(ir)
+ self.assertEqual([], error.filter_errors(errors))
+ field = ir.module[0].type[0].subtype[0].structure.field[0]
+ self.assertEqual(
+ ["Foo", "Foo"],
+ list(field.type.atomic_type.reference.canonical_name.object_path),
+ )
class ResolveFieldReferencesTest(unittest.TestCase):
- """Tests for symbol_resolver.resolve_field_references()."""
+ """Tests for symbol_resolver.resolve_field_references()."""
- def _construct_ir_multiple(self, file_dict, primary_emb_name):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- primary_emb_name,
- test_util.dict_file_reader(file_dict),
- stop_before_step="resolve_field_references")
- assert not errors
- return ir
+ def _construct_ir_multiple(self, file_dict, primary_emb_name):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ primary_emb_name,
+ test_util.dict_file_reader(file_dict),
+ stop_before_step="resolve_field_references",
+ )
+ assert not errors
+ return ir
- def _construct_ir(self, emb_text, name="happy.emb"):
- return self._construct_ir_multiple({name: emb_text}, name)
+ def _construct_ir(self, emb_text, name="happy.emb"):
+ return self._construct_ir_multiple({name: emb_text}, name)
- def test_subfield_resolution(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg gg\n"
- " 1 [+gg.qq] UInt:8[] data\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq\n", "subfield.emb")
- errors = symbol_resolver.resolve_field_references(ir)
- self.assertFalse(errors)
- ff = ir.module[0].type[0]
- location_end_path = ff.structure.field[1].location.size.field_reference.path
- self.assertEqual(["Ff", "gg"],
- list(location_end_path[0].canonical_name.object_path))
- self.assertEqual(["Gg", "qq"],
- list(location_end_path[1].canonical_name.object_path))
+ def test_subfield_resolution(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg gg\n"
+ " 1 [+gg.qq] UInt:8[] data\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq\n",
+ "subfield.emb",
+ )
+ errors = symbol_resolver.resolve_field_references(ir)
+ self.assertFalse(errors)
+ ff = ir.module[0].type[0]
+ location_end_path = ff.structure.field[1].location.size.field_reference.path
+ self.assertEqual(
+ ["Ff", "gg"], list(location_end_path[0].canonical_name.object_path)
+ )
+ self.assertEqual(
+ ["Gg", "qq"], list(location_end_path[1].canonical_name.object_path)
+ )
- def test_aliased_subfield_resolution(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg real_gg\n"
- " 1 [+gg.qq] UInt:8[] data\n"
- " let gg = real_gg\n"
- "struct Gg:\n"
- " 0 [+1] UInt real_qq\n"
- " let qq = real_qq", "subfield.emb")
- errors = symbol_resolver.resolve_field_references(ir)
- self.assertFalse(errors)
- ff = ir.module[0].type[0]
- location_end_path = ff.structure.field[1].location.size.field_reference.path
- self.assertEqual(["Ff", "gg"],
- list(location_end_path[0].canonical_name.object_path))
- self.assertEqual(["Gg", "qq"],
- list(location_end_path[1].canonical_name.object_path))
+ def test_aliased_subfield_resolution(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg real_gg\n"
+ " 1 [+gg.qq] UInt:8[] data\n"
+ " let gg = real_gg\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt real_qq\n"
+ " let qq = real_qq",
+ "subfield.emb",
+ )
+ errors = symbol_resolver.resolve_field_references(ir)
+ self.assertFalse(errors)
+ ff = ir.module[0].type[0]
+ location_end_path = ff.structure.field[1].location.size.field_reference.path
+ self.assertEqual(
+ ["Ff", "gg"], list(location_end_path[0].canonical_name.object_path)
+ )
+ self.assertEqual(
+ ["Gg", "qq"], list(location_end_path[1].canonical_name.object_path)
+ )
- def test_aliased_aliased_subfield_resolution(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg really_real_gg\n"
- " 1 [+gg.qq] UInt:8[] data\n"
- " let gg = real_gg\n"
- " let real_gg = really_real_gg\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq\n", "subfield.emb")
- errors = symbol_resolver.resolve_field_references(ir)
- self.assertFalse(errors)
- ff = ir.module[0].type[0]
- location_end_path = ff.structure.field[1].location.size.field_reference.path
- self.assertEqual(["Ff", "gg"],
- list(location_end_path[0].canonical_name.object_path))
- self.assertEqual(["Gg", "qq"],
- list(location_end_path[1].canonical_name.object_path))
+ def test_aliased_aliased_subfield_resolution(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg really_real_gg\n"
+ " 1 [+gg.qq] UInt:8[] data\n"
+ " let gg = real_gg\n"
+ " let real_gg = really_real_gg\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq\n",
+ "subfield.emb",
+ )
+ errors = symbol_resolver.resolve_field_references(ir)
+ self.assertFalse(errors)
+ ff = ir.module[0].type[0]
+ location_end_path = ff.structure.field[1].location.size.field_reference.path
+ self.assertEqual(
+ ["Ff", "gg"], list(location_end_path[0].canonical_name.object_path)
+ )
+ self.assertEqual(
+ ["Gg", "qq"], list(location_end_path[1].canonical_name.object_path)
+ )
- def test_subfield_resolution_fails(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg gg\n"
- " 1 [+gg.rr] UInt:8[] data\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq\n", "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- self.assertEqual([
- [error.error("subfield.emb", ir.module[0].type[0].structure.field[
- 1].location.size.field_reference.path[1].source_name[
- 0].source_location, "No candidate for 'rr'")]
- ], errors)
+ def test_subfield_resolution_fails(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg gg\n"
+ " 1 [+gg.rr] UInt:8[] data\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "subfield.emb",
+ ir.module[0]
+ .type[0]
+ .structure.field[1]
+ .location.size.field_reference.path[1]
+ .source_name[0]
+ .source_location,
+ "No candidate for 'rr'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_subfield_resolution_failure_shortcuts_further_resolution(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg gg\n"
- " 1 [+gg.rr.qq] UInt:8[] data\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq\n", "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- self.assertEqual([
- [error.error("subfield.emb", ir.module[0].type[0].structure.field[
- 1].location.size.field_reference.path[1].source_name[
- 0].source_location, "No candidate for 'rr'")]
- ], errors)
+ def test_subfield_resolution_failure_shortcuts_further_resolution(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg gg\n"
+ " 1 [+gg.rr.qq] UInt:8[] data\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "subfield.emb",
+ ir.module[0]
+ .type[0]
+ .structure.field[1]
+ .location.size.field_reference.path[1]
+ .source_name[0]
+ .source_location,
+ "No candidate for 'rr'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_subfield_resolution_failure_with_aliased_name(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg gg\n"
- " 1 [+gg.gg] UInt:8[] data\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq\n", "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- self.assertEqual([
- [error.error("subfield.emb", ir.module[0].type[0].structure.field[
- 1].location.size.field_reference.path[1].source_name[
- 0].source_location, "No candidate for 'gg'")]
- ], errors)
+ def test_subfield_resolution_failure_with_aliased_name(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg gg\n"
+ " 1 [+gg.gg] UInt:8[] data\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "subfield.emb",
+ ir.module[0]
+ .type[0]
+ .structure.field[1]
+ .location.size.field_reference.path[1]
+ .source_name[0]
+ .source_location,
+ "No candidate for 'gg'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_subfield_resolution_failure_with_array(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg[1] gg\n"
- " 1 [+gg.qq] UInt:8[] data\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq\n", "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- self.assertEqual([
- [error.error("subfield.emb", ir.module[0].type[0].structure.field[
- 1].location.size.field_reference.path[0].source_name[
- 0].source_location, "Cannot access member of array 'gg'")]
- ], errors)
+ def test_subfield_resolution_failure_with_array(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg[1] gg\n"
+ " 1 [+gg.qq] UInt:8[] data\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "subfield.emb",
+ ir.module[0]
+ .type[0]
+ .structure.field[1]
+ .location.size.field_reference.path[0]
+ .source_name[0]
+ .source_location,
+ "Cannot access member of array 'gg'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_subfield_resolution_failure_with_int(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] UInt gg_source\n"
- " 1 [+gg.qq] UInt:8[] data\n"
- " let gg = gg_source + 1\n",
- "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- error_field = ir.module[0].type[0].structure.field[1]
- error_reference = error_field.location.size.field_reference
- error_location = error_reference.path[0].source_name[0].source_location
- self.assertEqual([
- [error.error("subfield.emb", error_location,
- "Cannot access member of noncomposite field 'gg'")]
- ], errors)
+ def test_subfield_resolution_failure_with_int(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] UInt gg_source\n"
+ " 1 [+gg.qq] UInt:8[] data\n"
+ " let gg = gg_source + 1\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ error_field = ir.module[0].type[0].structure.field[1]
+ error_reference = error_field.location.size.field_reference
+ error_location = error_reference.path[0].source_name[0].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "subfield.emb",
+ error_location,
+ "Cannot access member of noncomposite field 'gg'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_subfield_resolution_failure_with_int_no_cascade(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] UInt gg_source\n"
- " 1 [+qqx] UInt:8[] data\n"
- " let gg = gg_source + 1\n"
- " let yy = gg.no_field\n"
- " let qqx = yy.x\n"
- " let qqy = yy.y\n",
- "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- error_field = ir.module[0].type[0].structure.field[3]
- error_reference = error_field.read_transform.field_reference
- error_location = error_reference.path[0].source_name[0].source_location
- self.assertEqual([
- [error.error("subfield.emb", error_location,
- "Cannot access member of noncomposite field 'gg'")]
- ], errors)
+ def test_subfield_resolution_failure_with_int_no_cascade(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] UInt gg_source\n"
+ " 1 [+qqx] UInt:8[] data\n"
+ " let gg = gg_source + 1\n"
+ " let yy = gg.no_field\n"
+ " let qqx = yy.x\n"
+ " let qqy = yy.y\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ error_field = ir.module[0].type[0].structure.field[3]
+ error_reference = error_field.read_transform.field_reference
+ error_location = error_reference.path[0].source_name[0].source_location
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "subfield.emb",
+ error_location,
+ "Cannot access member of noncomposite field 'gg'",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_subfield_resolution_failure_with_abbreviation(self):
- ir = self._construct_ir(
- "struct Ff:\n"
- " 0 [+1] Gg gg\n"
- " 1 [+gg.q] UInt:8[] data\n"
- "struct Gg:\n"
- " 0 [+1] UInt qq (q)\n", "subfield.emb")
- errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
- self.assertEqual([
- # TODO(bolms): Make the error message clearer, in this case.
- [error.error("subfield.emb", ir.module[0].type[0].structure.field[
- 1].location.size.field_reference.path[1].source_name[
- 0].source_location, "No candidate for 'q'")]
- ], errors)
+ def test_subfield_resolution_failure_with_abbreviation(self):
+ ir = self._construct_ir(
+ "struct Ff:\n"
+ " 0 [+1] Gg gg\n"
+ " 1 [+gg.q] UInt:8[] data\n"
+ "struct Gg:\n"
+ " 0 [+1] UInt qq (q)\n",
+ "subfield.emb",
+ )
+ errors = error.filter_errors(symbol_resolver.resolve_field_references(ir))
+ self.assertEqual(
+ [
+ # TODO(bolms): Make the error message clearer, in this case.
+ [
+ error.error(
+ "subfield.emb",
+ ir.module[0]
+ .type[0]
+ .structure.field[1]
+ .location.size.field_reference.path[1]
+ .source_name[0]
+ .source_location,
+ "No candidate for 'q'",
+ )
+ ]
+ ],
+ errors,
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index 73be294..1b331a3 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -24,27 +24,28 @@
def _mark_as_synthetic(proto):
- """Marks all source_locations in proto with is_synthetic=True."""
- if not isinstance(proto, ir_data.Message):
- return
- if hasattr(proto, "source_location"):
- ir_data_utils.builder(proto).source_location.is_synthetic = True
- for spec, value in ir_data_utils.get_set_fields(proto):
- if spec.name != "source_location" and spec.is_dataclass:
- if spec.is_sequence:
- for i in value:
- _mark_as_synthetic(i)
- else:
- _mark_as_synthetic(value)
+ """Marks all source_locations in proto with is_synthetic=True."""
+ if not isinstance(proto, ir_data.Message):
+ return
+ if hasattr(proto, "source_location"):
+ ir_data_utils.builder(proto).source_location.is_synthetic = True
+ for spec, value in ir_data_utils.get_set_fields(proto):
+ if spec.name != "source_location" and spec.is_dataclass:
+ if spec.is_sequence:
+ for i in value:
+ _mark_as_synthetic(i)
+ else:
+ _mark_as_synthetic(value)
def _skip_text_output_attribute():
- """Returns the IR for a [text_output: "Skip"] attribute."""
- result = ir_data.Attribute(
- name=ir_data.Word(text=attributes.TEXT_OUTPUT),
- value=ir_data.AttributeValue(string_constant=ir_data.String(text="Skip")))
- _mark_as_synthetic(result)
- return result
+ """Returns the IR for a [text_output: "Skip"] attribute."""
+ result = ir_data.Attribute(
+ name=ir_data.Word(text=attributes.TEXT_OUTPUT),
+ value=ir_data.AttributeValue(string_constant=ir_data.String(text="Skip")),
+ )
+ _mark_as_synthetic(result)
+ return result
# The existence condition for an alias for an anonymous bits' field is the union
@@ -52,122 +53,133 @@
# for the field within. The 'x' and 'x.y' are placeholders here; they'll be
# overwritten in _add_anonymous_aliases.
_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON = expression_parser.parse(
- "$present(x) && $present(x.y)")
+ "$present(x) && $present(x.y)"
+)
def _add_anonymous_aliases(structure, type_definition):
- """Adds synthetic alias fields for all fields in anonymous fields.
+ """Adds synthetic alias fields for all fields in anonymous fields.
- This essentially completes the rewrite of this:
+ This essentially completes the rewrite of this:
- struct Foo:
- 0 [+4] bits:
- 0 [+1] Flag low
- 31 [+1] Flag high
+ struct Foo:
+ 0 [+4] bits:
+ 0 [+1] Flag low
+ 31 [+1] Flag high
- Into this:
+ Into this:
- struct Foo:
- bits EmbossReservedAnonymous0:
- [text_output: "Skip"]
- 0 [+1] Flag low
- 31 [+1] Flag high
- 0 [+4] EmbossReservedAnonymous0 emboss_reserved_anonymous_1
- let low = emboss_reserved_anonymous_1.low
- let high = emboss_reserved_anonymous_1.high
+ struct Foo:
+ bits EmbossReservedAnonymous0:
+ [text_output: "Skip"]
+ 0 [+1] Flag low
+ 31 [+1] Flag high
+ 0 [+4] EmbossReservedAnonymous0 emboss_reserved_anonymous_1
+ let low = emboss_reserved_anonymous_1.low
+ let high = emboss_reserved_anonymous_1.high
- Note that this pass runs very, very early -- even before symbols have been
- resolved -- so very little in ir_util will work at this point.
+ Note that this pass runs very, very early -- even before symbols have been
+ resolved -- so very little in ir_util will work at this point.
- Arguments:
- structure: The ir_data.Structure on which to synthesize fields.
- type_definition: The ir_data.TypeDefinition containing structure.
+ Arguments:
+ structure: The ir_data.Structure on which to synthesize fields.
+ type_definition: The ir_data.TypeDefinition containing structure.
- Returns:
- None
- """
- new_fields = []
- for field in structure.field:
- new_fields.append(field)
- if not field.name.is_anonymous:
- continue
- field.attribute.extend([_skip_text_output_attribute()])
- for subtype in type_definition.subtype:
- if (subtype.name.name.text ==
- field.type.atomic_type.reference.source_name[-1].text):
- field_type = subtype
- break
- else:
- assert False, ("Unable to find corresponding type {} for anonymous field "
- "in {}.".format(
- field.type.atomic_type.reference, type_definition))
- anonymous_reference = ir_data.Reference(source_name=[field.name.name])
- anonymous_field_reference = ir_data.FieldReference(
- path=[anonymous_reference])
- for subfield in field_type.structure.field:
- alias_field_reference = ir_data.FieldReference(
- path=[
- anonymous_reference,
- ir_data.Reference(source_name=[subfield.name.name]),
- ]
- )
- new_existence_condition = ir_data_utils.copy(_ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON)
- existence_clauses = ir_data_utils.builder(new_existence_condition).function.args
- existence_clauses[0].function.args[0].field_reference.CopyFrom(
- anonymous_field_reference)
- existence_clauses[1].function.args[0].field_reference.CopyFrom(
- alias_field_reference)
- new_read_transform = ir_data.Expression(
- field_reference=ir_data_utils.copy(alias_field_reference))
- # This treats *most* of the alias field as synthetic, but not its name(s):
- # leaving the name(s) as "real" means that symbol collisions with the
- # surrounding structure will be properly reported to the user.
- _mark_as_synthetic(new_existence_condition)
- _mark_as_synthetic(new_read_transform)
- new_alias = ir_data.Field(
- read_transform=new_read_transform,
- existence_condition=new_existence_condition,
- name=ir_data_utils.copy(subfield.name))
- if subfield.HasField("abbreviation"):
- ir_data_utils.builder(new_alias).abbreviation.CopyFrom(subfield.abbreviation)
- _mark_as_synthetic(new_alias.existence_condition)
- _mark_as_synthetic(new_alias.read_transform)
- new_fields.append(new_alias)
- # Since the alias field's name(s) are "real," it is important to mark the
- # original field's name(s) as synthetic, to avoid duplicate error
- # messages.
- _mark_as_synthetic(subfield.name)
- if subfield.HasField("abbreviation"):
- _mark_as_synthetic(subfield.abbreviation)
- del structure.field[:]
- structure.field.extend(new_fields)
+ Returns:
+ None
+ """
+ new_fields = []
+ for field in structure.field:
+ new_fields.append(field)
+ if not field.name.is_anonymous:
+ continue
+ field.attribute.extend([_skip_text_output_attribute()])
+ for subtype in type_definition.subtype:
+ if (
+ subtype.name.name.text
+ == field.type.atomic_type.reference.source_name[-1].text
+ ):
+ field_type = subtype
+ break
+ else:
+ assert False, (
+ "Unable to find corresponding type {} for anonymous field "
+ "in {}.".format(field.type.atomic_type.reference, type_definition)
+ )
+ anonymous_reference = ir_data.Reference(source_name=[field.name.name])
+ anonymous_field_reference = ir_data.FieldReference(path=[anonymous_reference])
+ for subfield in field_type.structure.field:
+ alias_field_reference = ir_data.FieldReference(
+ path=[
+ anonymous_reference,
+ ir_data.Reference(source_name=[subfield.name.name]),
+ ]
+ )
+ new_existence_condition = ir_data_utils.copy(
+ _ANONYMOUS_BITS_ALIAS_EXISTENCE_SKELETON
+ )
+ existence_clauses = ir_data_utils.builder(
+ new_existence_condition
+ ).function.args
+ existence_clauses[0].function.args[0].field_reference.CopyFrom(
+ anonymous_field_reference
+ )
+ existence_clauses[1].function.args[0].field_reference.CopyFrom(
+ alias_field_reference
+ )
+ new_read_transform = ir_data.Expression(
+ field_reference=ir_data_utils.copy(alias_field_reference)
+ )
+ # This treats *most* of the alias field as synthetic, but not its name(s):
+ # leaving the name(s) as "real" means that symbol collisions with the
+ # surrounding structure will be properly reported to the user.
+ _mark_as_synthetic(new_existence_condition)
+ _mark_as_synthetic(new_read_transform)
+ new_alias = ir_data.Field(
+ read_transform=new_read_transform,
+ existence_condition=new_existence_condition,
+ name=ir_data_utils.copy(subfield.name),
+ )
+ if subfield.HasField("abbreviation"):
+ ir_data_utils.builder(new_alias).abbreviation.CopyFrom(
+ subfield.abbreviation
+ )
+ _mark_as_synthetic(new_alias.existence_condition)
+ _mark_as_synthetic(new_alias.read_transform)
+ new_fields.append(new_alias)
+ # Since the alias field's name(s) are "real," it is important to mark the
+ # original field's name(s) as synthetic, to avoid duplicate error
+ # messages.
+ _mark_as_synthetic(subfield.name)
+ if subfield.HasField("abbreviation"):
+ _mark_as_synthetic(subfield.abbreviation)
+ del structure.field[:]
+ structure.field.extend(new_fields)
_SIZE_BOUNDS = {
"$max_size_in_bits": expression_parser.parse("$upper_bound($size_in_bits)"),
"$min_size_in_bits": expression_parser.parse("$lower_bound($size_in_bits)"),
- "$max_size_in_bytes": expression_parser.parse(
- "$upper_bound($size_in_bytes)"),
- "$min_size_in_bytes": expression_parser.parse(
- "$lower_bound($size_in_bytes)"),
+ "$max_size_in_bytes": expression_parser.parse("$upper_bound($size_in_bytes)"),
+ "$min_size_in_bytes": expression_parser.parse("$lower_bound($size_in_bytes)"),
}
def _add_size_bound_virtuals(structure, type_definition):
- """Adds ${min,max}_size_in_{bits,bytes} virtual fields to structure."""
- names = {
- ir_data.AddressableUnit.BIT: ("$max_size_in_bits", "$min_size_in_bits"),
- ir_data.AddressableUnit.BYTE: ("$max_size_in_bytes", "$min_size_in_bytes"),
- }
- for name in names[type_definition.addressable_unit]:
- bound_field = ir_data.Field(
- read_transform=_SIZE_BOUNDS[name],
- name=ir_data.NameDefinition(name=ir_data.Word(text=name)),
- existence_condition=expression_parser.parse("true"),
- attribute=[_skip_text_output_attribute()]
- )
- _mark_as_synthetic(bound_field.read_transform)
- structure.field.extend([bound_field])
+ """Adds ${min,max}_size_in_{bits,bytes} virtual fields to structure."""
+ names = {
+ ir_data.AddressableUnit.BIT: ("$max_size_in_bits", "$min_size_in_bits"),
+ ir_data.AddressableUnit.BYTE: ("$max_size_in_bytes", "$min_size_in_bytes"),
+ }
+ for name in names[type_definition.addressable_unit]:
+ bound_field = ir_data.Field(
+ read_transform=_SIZE_BOUNDS[name],
+ name=ir_data.NameDefinition(name=ir_data.Word(text=name)),
+ existence_condition=expression_parser.parse("true"),
+ attribute=[_skip_text_output_attribute()],
+ )
+ _mark_as_synthetic(bound_field.read_transform)
+ structure.field.extend([bound_field])
# Each non-virtual field in a structure generates a clause that is passed to
@@ -177,42 +189,43 @@
# physical fields don't end up with a zero-argument `$max()` call, which would
# fail type checking.
_SIZE_CLAUSE_SKELETON = expression_parser.parse(
- "existence_condition ? start + size : 0")
+ "existence_condition ? start + size : 0"
+)
_SIZE_SKELETON = expression_parser.parse("$max(0)")
def _add_size_virtuals(structure, type_definition):
- """Adds a $size_in_bits or $size_in_bytes virtual field to structure."""
- names = {
- ir_data.AddressableUnit.BIT: "$size_in_bits",
- ir_data.AddressableUnit.BYTE: "$size_in_bytes",
- }
- size_field_name = names[type_definition.addressable_unit]
- size_clauses = []
- for field in structure.field:
- # Virtual fields do not have a physical location, and thus do not contribute
- # to the size of the structure.
- if ir_util.field_is_virtual(field):
- continue
- size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON)
- size_clause = ir_data_utils.builder(size_clause_ir)
- # Copy the appropriate clauses into `existence_condition ? start + size : 0`
- size_clause.function.args[0].CopyFrom(field.existence_condition)
- size_clause.function.args[1].function.args[0].CopyFrom(field.location.start)
- size_clause.function.args[1].function.args[1].CopyFrom(field.location.size)
- size_clauses.append(size_clause_ir)
- size_expression = ir_data_utils.copy(_SIZE_SKELETON)
- size_expression.function.args.extend(size_clauses)
- _mark_as_synthetic(size_expression)
- size_field = ir_data.Field(
- read_transform=size_expression,
- name=ir_data.NameDefinition(name=ir_data.Word(text=size_field_name)),
- existence_condition=ir_data.Expression(
- boolean_constant=ir_data.BooleanConstant(value=True)
- ),
- attribute=[_skip_text_output_attribute()]
- )
- structure.field.extend([size_field])
+ """Adds a $size_in_bits or $size_in_bytes virtual field to structure."""
+ names = {
+ ir_data.AddressableUnit.BIT: "$size_in_bits",
+ ir_data.AddressableUnit.BYTE: "$size_in_bytes",
+ }
+ size_field_name = names[type_definition.addressable_unit]
+ size_clauses = []
+ for field in structure.field:
+ # Virtual fields do not have a physical location, and thus do not contribute
+ # to the size of the structure.
+ if ir_util.field_is_virtual(field):
+ continue
+ size_clause_ir = ir_data_utils.copy(_SIZE_CLAUSE_SKELETON)
+ size_clause = ir_data_utils.builder(size_clause_ir)
+ # Copy the appropriate clauses into `existence_condition ? start + size : 0`
+ size_clause.function.args[0].CopyFrom(field.existence_condition)
+ size_clause.function.args[1].function.args[0].CopyFrom(field.location.start)
+ size_clause.function.args[1].function.args[1].CopyFrom(field.location.size)
+ size_clauses.append(size_clause_ir)
+ size_expression = ir_data_utils.copy(_SIZE_SKELETON)
+ size_expression.function.args.extend(size_clauses)
+ _mark_as_synthetic(size_expression)
+ size_field = ir_data.Field(
+ read_transform=size_expression,
+ name=ir_data.NameDefinition(name=ir_data.Word(text=size_field_name)),
+ existence_condition=ir_data.Expression(
+ boolean_constant=ir_data.BooleanConstant(value=True)
+ ),
+ attribute=[_skip_text_output_attribute()],
+ )
+ structure.field.extend([size_field])
# The replacement for the "$next" keyword is a simple "start + size" expression.
@@ -220,113 +233,134 @@
_NEXT_KEYWORD_REPLACEMENT_EXPRESSION = expression_parser.parse("x + y")
-def _maybe_replace_next_keyword_in_expression(expression_ir, last_location,
- source_file_name, errors):
- if not expression_ir.HasField("builtin_reference"):
- return
- if ir_data_utils.reader(expression_ir).builtin_reference.canonical_name.object_path[0] != "$next":
- return
- if not last_location:
- errors.append([
- error.error(source_file_name, expression_ir.source_location,
- "`$next` may not be used in the first physical field of a " +
- "structure; perhaps you meant `0`?")
- ])
- return
- original_location = expression_ir.source_location
- expression = ir_data_utils.builder(expression_ir)
- expression.CopyFrom(_NEXT_KEYWORD_REPLACEMENT_EXPRESSION)
- expression.function.args[0].CopyFrom(last_location.start)
- expression.function.args[1].CopyFrom(last_location.size)
- expression.source_location.CopyFrom(original_location)
- _mark_as_synthetic(expression.function)
+def _maybe_replace_next_keyword_in_expression(
+ expression_ir, last_location, source_file_name, errors
+):
+ if not expression_ir.HasField("builtin_reference"):
+ return
+ if (
+ ir_data_utils.reader(
+ expression_ir
+ ).builtin_reference.canonical_name.object_path[0]
+ != "$next"
+ ):
+ return
+ if not last_location:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression_ir.source_location,
+ "`$next` may not be used in the first physical field of a "
+ + "structure; perhaps you meant `0`?",
+ )
+ ]
+ )
+ return
+ original_location = expression_ir.source_location
+ expression = ir_data_utils.builder(expression_ir)
+ expression.CopyFrom(_NEXT_KEYWORD_REPLACEMENT_EXPRESSION)
+ expression.function.args[0].CopyFrom(last_location.start)
+ expression.function.args[1].CopyFrom(last_location.size)
+ expression.source_location.CopyFrom(original_location)
+ _mark_as_synthetic(expression.function)
def _check_for_bad_next_keyword_in_size(expression, source_file_name, errors):
- if not expression.HasField("builtin_reference"):
- return
- if expression.builtin_reference.canonical_name.object_path[0] != "$next":
- return
- errors.append([
- error.error(source_file_name, expression.source_location,
- "`$next` may only be used in the start expression of a " +
- "physical field.")
- ])
+ if not expression.HasField("builtin_reference"):
+ return
+ if expression.builtin_reference.canonical_name.object_path[0] != "$next":
+ return
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
+ "`$next` may only be used in the start expression of a "
+ + "physical field.",
+ )
+ ]
+ )
def _replace_next_keyword(structure, source_file_name, errors):
- last_physical_field_location = None
- new_errors = []
- for field in structure.field:
- if ir_util.field_is_virtual(field):
- # TODO(bolms): It could be useful to allow `$next` in a virtual field, in
- # order to reuse the value (say, to allow overlapping fields in a
- # mostly-packed structure), but it seems better to add `$end_of(field)`,
- # `$offset_of(field)`, and `$size_of(field)` constructs of some sort,
- # instead.
- continue
- traverse_ir.fast_traverse_node_top_down(
- field.location.size, [ir_data.Expression],
- _check_for_bad_next_keyword_in_size,
- parameters={
- "errors": new_errors,
- "source_file_name": source_file_name,
- })
- # If `$next` is misused in a field size, it can end up causing a
- # `RecursionError` in fast_traverse_node_top_down. (When the `$next` node
- # in the next field is replaced, its replacement gets traversed, but the
- # replacement also contains a `$next` node, leading to infinite recursion.)
- #
- # Technically, we could scan all of the sizes instead of bailing early, but
- # it seems relatively unlikely that someone will have `$next` in multiple
- # sizes and not figure out what is going on relatively quickly.
- if new_errors:
- errors.extend(new_errors)
- return
- traverse_ir.fast_traverse_node_top_down(
- field.location.start, [ir_data.Expression],
- _maybe_replace_next_keyword_in_expression,
- parameters={
- "last_location": last_physical_field_location,
- "errors": new_errors,
- "source_file_name": source_file_name,
- })
- # The only possible error from _maybe_replace_next_keyword_in_expression is
- # `$next` occurring in the start expression of the first physical field,
- # which leads to similar recursion issue if `$next` is used in the start
- # expression of the next physical field.
- if new_errors:
- errors.extend(new_errors)
- return
- last_physical_field_location = field.location
+ last_physical_field_location = None
+ new_errors = []
+ for field in structure.field:
+ if ir_util.field_is_virtual(field):
+ # TODO(bolms): It could be useful to allow `$next` in a virtual field, in
+ # order to reuse the value (say, to allow overlapping fields in a
+ # mostly-packed structure), but it seems better to add `$end_of(field)`,
+ # `$offset_of(field)`, and `$size_of(field)` constructs of some sort,
+ # instead.
+ continue
+ traverse_ir.fast_traverse_node_top_down(
+ field.location.size,
+ [ir_data.Expression],
+ _check_for_bad_next_keyword_in_size,
+ parameters={
+ "errors": new_errors,
+ "source_file_name": source_file_name,
+ },
+ )
+ # If `$next` is misused in a field size, it can end up causing a
+ # `RecursionError` in fast_traverse_node_top_down. (When the `$next` node
+ # in the next field is replaced, its replacement gets traversed, but the
+ # replacement also contains a `$next` node, leading to infinite recursion.)
+ #
+ # Technically, we could scan all of the sizes instead of bailing early, but
+ # it seems relatively unlikely that someone will have `$next` in multiple
+ # sizes and not figure out what is going on relatively quickly.
+ if new_errors:
+ errors.extend(new_errors)
+ return
+ traverse_ir.fast_traverse_node_top_down(
+ field.location.start,
+ [ir_data.Expression],
+ _maybe_replace_next_keyword_in_expression,
+ parameters={
+ "last_location": last_physical_field_location,
+ "errors": new_errors,
+ "source_file_name": source_file_name,
+ },
+ )
+ # The only possible error from _maybe_replace_next_keyword_in_expression is
+ # `$next` occurring in the start expression of the first physical field,
+ # which leads to similar recursion issue if `$next` is used in the start
+ # expression of the next physical field.
+ if new_errors:
+ errors.extend(new_errors)
+ return
+ last_physical_field_location = field.location
def _add_virtuals_to_structure(structure, type_definition):
- _add_anonymous_aliases(structure, type_definition)
- _add_size_virtuals(structure, type_definition)
- _add_size_bound_virtuals(structure, type_definition)
+ _add_anonymous_aliases(structure, type_definition)
+ _add_size_virtuals(structure, type_definition)
+ _add_size_bound_virtuals(structure, type_definition)
def desugar(ir):
- """Translates pure syntactic sugar to its desugared form.
+ """Translates pure syntactic sugar to its desugared form.
- Replaces `$next` symbols with the start+length of the previous physical
- field.
+ Replaces `$next` symbols with the start+length of the previous physical
+ field.
- Adds aliases for all fields in anonymous `bits` to the enclosing structure.
+ Adds aliases for all fields in anonymous `bits` to the enclosing structure.
- Arguments:
- ir: The IR to desugar.
+ Arguments:
+ ir: The IR to desugar.
- Returns:
- A list of errors, or an empty list.
- """
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure], _replace_next_keyword,
- parameters={"errors": errors})
- if errors:
- return errors
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Structure], _add_virtuals_to_structure)
- return []
+ Returns:
+ A list of errors, or an empty list.
+ """
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Structure], _replace_next_keyword, parameters={"errors": errors}
+ )
+ if errors:
+ return errors
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Structure], _add_virtuals_to_structure
+ )
+ return []
diff --git a/compiler/front_end/synthetics_test.py b/compiler/front_end/synthetics_test.py
index 85a3dfb..ac8e671 100644
--- a/compiler/front_end/synthetics_test.py
+++ b/compiler/front_end/synthetics_test.py
@@ -24,246 +24,300 @@
class SyntheticsTest(unittest.TestCase):
- def _find_attribute(self, field, name):
- result = None
- for attribute in field.attribute:
- if attribute.name.text == name:
- self.assertIsNone(result)
- result = attribute
- self.assertIsNotNone(result)
- return result
+ def _find_attribute(self, field, name):
+ result = None
+ for attribute in field.attribute:
+ if attribute.name.text == name:
+ self.assertIsNone(result)
+ result = attribute
+ self.assertIsNotNone(result)
+ return result
- def _make_ir(self, emb_text):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": emb_text}),
- stop_before_step="desugar")
- assert not errors, errors
- return ir
+ def _make_ir(self, emb_text):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_text}),
+ stop_before_step="desugar",
+ )
+ assert not errors, errors
+ return ir
- def test_nothing_to_do(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt:8[] y\n")
- self.assertEqual([], synthetics.desugar(ir))
+ def test_nothing_to_do(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+1] UInt:8[] y\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
- def test_adds_anonymous_bits_fields(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] Bar bar\n"
- " 4 [+4] UInt uint\n"
- " 1 [+1] bits:\n"
- " 0 [+4] Bits nested_bits\n"
- "enum Bar:\n"
- " BAR = 0\n"
- "bits Bits:\n"
- " 0 [+4] UInt uint\n")
- self.assertEqual([], synthetics.desugar(ir))
- structure = ir.module[0].type[0].structure
- # The first field should be the anonymous bits structure.
- self.assertTrue(structure.field[0].HasField("location"))
- # Then the aliases generated for those structures.
- self.assertEqual("bar", structure.field[1].name.name.text)
- self.assertEqual("uint", structure.field[2].name.name.text)
- # Then the second anonymous bits.
- self.assertTrue(structure.field[3].HasField("location"))
- # Then the alias from the second anonymous bits.
- self.assertEqual("nested_bits", structure.field[4].name.name.text)
+ def test_adds_anonymous_bits_fields(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] bits:\n"
+ " 0 [+4] Bar bar\n"
+ " 4 [+4] UInt uint\n"
+ " 1 [+1] bits:\n"
+ " 0 [+4] Bits nested_bits\n"
+ "enum Bar:\n"
+ " BAR = 0\n"
+ "bits Bits:\n"
+ " 0 [+4] UInt uint\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ structure = ir.module[0].type[0].structure
+ # The first field should be the anonymous bits structure.
+ self.assertTrue(structure.field[0].HasField("location"))
+ # Then the aliases generated for those structures.
+ self.assertEqual("bar", structure.field[1].name.name.text)
+ self.assertEqual("uint", structure.field[2].name.name.text)
+ # Then the second anonymous bits.
+ self.assertTrue(structure.field[3].HasField("location"))
+ # Then the alias from the second anonymous bits.
+ self.assertEqual("nested_bits", structure.field[4].name.name.text)
- def test_adds_correct_existence_condition(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] UInt bar\n")
- self.assertEqual([], synthetics.desugar(ir))
- bits_field = ir.module[0].type[0].structure.field[0]
- alias_field = ir.module[0].type[0].structure.field[1]
- self.assertEqual("bar", alias_field.name.name.text)
- self.assertEqual(bits_field.name.name.text,
- alias_field.existence_condition.function.args[0].function.
- args[0].field_reference.path[0].source_name[-1].text)
- self.assertEqual(bits_field.name.name.text,
- alias_field.existence_condition.function.args[1].function.
- args[0].field_reference.path[0].source_name[-1].text)
- self.assertEqual("bar",
- alias_field.existence_condition.function.args[1].function.
- args[0].field_reference.path[1].source_name[-1].text)
- self.assertEqual(
- ir_data.FunctionMapping.PRESENCE,
- alias_field.existence_condition.function.args[0].function.function)
- self.assertEqual(
- ir_data.FunctionMapping.PRESENCE,
- alias_field.existence_condition.function.args[1].function.function)
- self.assertEqual(ir_data.FunctionMapping.AND,
- alias_field.existence_condition.function.function)
+ def test_adds_correct_existence_condition(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ bits_field = ir.module[0].type[0].structure.field[0]
+ alias_field = ir.module[0].type[0].structure.field[1]
+ self.assertEqual("bar", alias_field.name.name.text)
+ self.assertEqual(
+ bits_field.name.name.text,
+ alias_field.existence_condition.function.args[0]
+ .function.args[0]
+ .field_reference.path[0]
+ .source_name[-1]
+ .text,
+ )
+ self.assertEqual(
+ bits_field.name.name.text,
+ alias_field.existence_condition.function.args[1]
+ .function.args[0]
+ .field_reference.path[0]
+ .source_name[-1]
+ .text,
+ )
+ self.assertEqual(
+ "bar",
+ alias_field.existence_condition.function.args[1]
+ .function.args[0]
+ .field_reference.path[1]
+ .source_name[-1]
+ .text,
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.PRESENCE,
+ alias_field.existence_condition.function.args[0].function.function,
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.PRESENCE,
+ alias_field.existence_condition.function.args[1].function.function,
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.AND,
+ alias_field.existence_condition.function.function,
+ )
- def test_adds_correct_read_transform(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] UInt bar\n")
- self.assertEqual([], synthetics.desugar(ir))
- bits_field = ir.module[0].type[0].structure.field[0]
- alias_field = ir.module[0].type[0].structure.field[1]
- self.assertEqual("bar", alias_field.name.name.text)
- self.assertEqual(
- bits_field.name.name.text,
- alias_field.read_transform.field_reference.path[0].source_name[-1].text)
- self.assertEqual(
- "bar",
- alias_field.read_transform.field_reference.path[1].source_name[-1].text)
+ def test_adds_correct_read_transform(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ bits_field = ir.module[0].type[0].structure.field[0]
+ alias_field = ir.module[0].type[0].structure.field[1]
+ self.assertEqual("bar", alias_field.name.name.text)
+ self.assertEqual(
+ bits_field.name.name.text,
+ alias_field.read_transform.field_reference.path[0].source_name[-1].text,
+ )
+ self.assertEqual(
+ "bar",
+ alias_field.read_transform.field_reference.path[1].source_name[-1].text,
+ )
- def test_adds_correct_abbreviation(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] UInt bar\n"
- " 4 [+4] UInt baz (qux)\n")
- self.assertEqual([], synthetics.desugar(ir))
- bar_alias = ir.module[0].type[0].structure.field[1]
- baz_alias = ir.module[0].type[0].structure.field[2]
- self.assertFalse(bar_alias.HasField("abbreviation"))
- self.assertEqual("qux", baz_alias.abbreviation.text)
+ def test_adds_correct_abbreviation(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] bits:\n"
+ " 0 [+4] UInt bar\n"
+ " 4 [+4] UInt baz (qux)\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ bar_alias = ir.module[0].type[0].structure.field[1]
+ baz_alias = ir.module[0].type[0].structure.field[2]
+ self.assertFalse(bar_alias.HasField("abbreviation"))
+ self.assertEqual("qux", baz_alias.abbreviation.text)
- def test_anonymous_bits_sets_correct_is_synthetic(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] UInt bar (b)\n")
- self.assertEqual([], synthetics.desugar(ir))
- bits_field = ir.module[0].type[0].subtype[0].structure.field[0]
- alias_field = ir.module[0].type[0].structure.field[1]
- self.assertFalse(alias_field.name.source_location.is_synthetic)
- self.assertTrue(alias_field.HasField("abbreviation"))
- self.assertFalse(alias_field.abbreviation.source_location.is_synthetic)
- self.assertTrue(alias_field.HasField("read_transform"))
- read_alias = alias_field.read_transform
- self.assertTrue(read_alias.source_location.is_synthetic)
- self.assertTrue(
- read_alias.field_reference.path[0].source_location.is_synthetic)
- alias_condition = alias_field.existence_condition
- self.assertTrue(alias_condition.source_location.is_synthetic)
- self.assertTrue(
- alias_condition.function.args[0].source_location.is_synthetic)
- self.assertTrue(bits_field.name.source_location.is_synthetic)
- self.assertTrue(bits_field.name.name.source_location.is_synthetic)
- self.assertTrue(bits_field.abbreviation.source_location.is_synthetic)
+ def test_anonymous_bits_sets_correct_is_synthetic(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar (b)\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ bits_field = ir.module[0].type[0].subtype[0].structure.field[0]
+ alias_field = ir.module[0].type[0].structure.field[1]
+ self.assertFalse(alias_field.name.source_location.is_synthetic)
+ self.assertTrue(alias_field.HasField("abbreviation"))
+ self.assertFalse(alias_field.abbreviation.source_location.is_synthetic)
+ self.assertTrue(alias_field.HasField("read_transform"))
+ read_alias = alias_field.read_transform
+ self.assertTrue(read_alias.source_location.is_synthetic)
+ self.assertTrue(read_alias.field_reference.path[0].source_location.is_synthetic)
+ alias_condition = alias_field.existence_condition
+ self.assertTrue(alias_condition.source_location.is_synthetic)
+ self.assertTrue(alias_condition.function.args[0].source_location.is_synthetic)
+ self.assertTrue(bits_field.name.source_location.is_synthetic)
+ self.assertTrue(bits_field.name.name.source_location.is_synthetic)
+ self.assertTrue(bits_field.abbreviation.source_location.is_synthetic)
- def test_adds_text_output_skip_attribute_to_anonymous_bits(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] UInt bar (b)\n")
- self.assertEqual([], synthetics.desugar(ir))
- bits_field = ir.module[0].type[0].structure.field[0]
- text_output_attribute = self._find_attribute(bits_field, "text_output")
- self.assertEqual("Skip", text_output_attribute.value.string_constant.text)
+ def test_adds_text_output_skip_attribute_to_anonymous_bits(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar (b)\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ bits_field = ir.module[0].type[0].structure.field[0]
+ text_output_attribute = self._find_attribute(bits_field, "text_output")
+ self.assertEqual("Skip", text_output_attribute.value.string_constant.text)
- def test_skip_attribute_is_marked_as_synthetic(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] bits:\n"
- " 0 [+4] UInt bar\n")
- self.assertEqual([], synthetics.desugar(ir))
- bits_field = ir.module[0].type[0].structure.field[0]
- attribute = self._find_attribute(bits_field, "text_output")
- self.assertTrue(attribute.source_location.is_synthetic)
- self.assertTrue(attribute.name.source_location.is_synthetic)
- self.assertTrue(attribute.value.source_location.is_synthetic)
- self.assertTrue(
- attribute.value.string_constant.source_location.is_synthetic)
+ def test_skip_attribute_is_marked_as_synthetic(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] bits:\n" " 0 [+4] UInt bar\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ bits_field = ir.module[0].type[0].structure.field[0]
+ attribute = self._find_attribute(bits_field, "text_output")
+ self.assertTrue(attribute.source_location.is_synthetic)
+ self.assertTrue(attribute.name.source_location.is_synthetic)
+ self.assertTrue(attribute.value.source_location.is_synthetic)
+ self.assertTrue(attribute.value.string_constant.source_location.is_synthetic)
- def test_adds_size_in_bytes(self):
- ir = self._make_ir("struct Foo:\n"
- " 1 [+l] UInt:8[] bytes\n"
- " 0 [+1] UInt length (l)\n")
- self.assertEqual([], synthetics.desugar(ir))
- structure = ir.module[0].type[0].structure
- size_in_bytes_field = structure.field[2]
- max_size_in_bytes_field = structure.field[3]
- min_size_in_bytes_field = structure.field[4]
- self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text)
- self.assertEqual(ir_data.FunctionMapping.MAXIMUM,
- size_in_bytes_field.read_transform.function.function)
- self.assertEqual("$max_size_in_bytes",
- max_size_in_bytes_field.name.name.text)
- self.assertEqual(ir_data.FunctionMapping.UPPER_BOUND,
- max_size_in_bytes_field.read_transform.function.function)
- self.assertEqual("$min_size_in_bytes",
- min_size_in_bytes_field.name.name.text)
- self.assertEqual(ir_data.FunctionMapping.LOWER_BOUND,
- min_size_in_bytes_field.read_transform.function.function)
- # The correctness of $size_in_bytes et al are tested much further down
- # stream, in tests of the generated C++ code.
+ def test_adds_size_in_bytes(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 1 [+l] UInt:8[] bytes\n"
+ " 0 [+1] UInt length (l)\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ structure = ir.module[0].type[0].structure
+ size_in_bytes_field = structure.field[2]
+ max_size_in_bytes_field = structure.field[3]
+ min_size_in_bytes_field = structure.field[4]
+ self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text)
+ self.assertEqual(
+ ir_data.FunctionMapping.MAXIMUM,
+ size_in_bytes_field.read_transform.function.function,
+ )
+ self.assertEqual("$max_size_in_bytes", max_size_in_bytes_field.name.name.text)
+ self.assertEqual(
+ ir_data.FunctionMapping.UPPER_BOUND,
+ max_size_in_bytes_field.read_transform.function.function,
+ )
+ self.assertEqual("$min_size_in_bytes", min_size_in_bytes_field.name.name.text)
+ self.assertEqual(
+ ir_data.FunctionMapping.LOWER_BOUND,
+ min_size_in_bytes_field.read_transform.function.function,
+ )
+ # The correctness of $size_in_bytes et al are tested much further down
+ # stream, in tests of the generated C++ code.
- def test_adds_size_in_bits(self):
- ir = self._make_ir("bits Foo:\n"
- " 1 [+9] UInt hi\n"
- " 0 [+1] Flag lo\n")
- self.assertEqual([], synthetics.desugar(ir))
- structure = ir.module[0].type[0].structure
- size_in_bits_field = structure.field[2]
- max_size_in_bits_field = structure.field[3]
- min_size_in_bits_field = structure.field[4]
- self.assertEqual("$size_in_bits", size_in_bits_field.name.name.text)
- self.assertEqual(ir_data.FunctionMapping.MAXIMUM,
- size_in_bits_field.read_transform.function.function)
- self.assertEqual("$max_size_in_bits",
- max_size_in_bits_field.name.name.text)
- self.assertEqual(ir_data.FunctionMapping.UPPER_BOUND,
- max_size_in_bits_field.read_transform.function.function)
- self.assertEqual("$min_size_in_bits",
- min_size_in_bits_field.name.name.text)
- self.assertEqual(ir_data.FunctionMapping.LOWER_BOUND,
- min_size_in_bits_field.read_transform.function.function)
- # The correctness of $size_in_bits et al are tested much further down
- # stream, in tests of the generated C++ code.
+ def test_adds_size_in_bits(self):
+ ir = self._make_ir("bits Foo:\n" " 1 [+9] UInt hi\n" " 0 [+1] Flag lo\n")
+ self.assertEqual([], synthetics.desugar(ir))
+ structure = ir.module[0].type[0].structure
+ size_in_bits_field = structure.field[2]
+ max_size_in_bits_field = structure.field[3]
+ min_size_in_bits_field = structure.field[4]
+ self.assertEqual("$size_in_bits", size_in_bits_field.name.name.text)
+ self.assertEqual(
+ ir_data.FunctionMapping.MAXIMUM,
+ size_in_bits_field.read_transform.function.function,
+ )
+ self.assertEqual("$max_size_in_bits", max_size_in_bits_field.name.name.text)
+ self.assertEqual(
+ ir_data.FunctionMapping.UPPER_BOUND,
+ max_size_in_bits_field.read_transform.function.function,
+ )
+ self.assertEqual("$min_size_in_bits", min_size_in_bits_field.name.name.text)
+ self.assertEqual(
+ ir_data.FunctionMapping.LOWER_BOUND,
+ min_size_in_bits_field.read_transform.function.function,
+ )
+ # The correctness of $size_in_bits et al are tested much further down
+ # stream, in tests of the generated C++ code.
- def test_adds_text_output_skip_attribute_to_size_in_bytes(self):
- ir = self._make_ir("struct Foo:\n"
- " 1 [+l] UInt:8[] bytes\n"
- " 0 [+1] UInt length (l)\n")
- self.assertEqual([], synthetics.desugar(ir))
- size_in_bytes_field = ir.module[0].type[0].structure.field[2]
- self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text)
- text_output_attribute = self._find_attribute(size_in_bytes_field,
- "text_output")
- self.assertEqual("Skip", text_output_attribute.value.string_constant.text)
+ def test_adds_text_output_skip_attribute_to_size_in_bytes(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 1 [+l] UInt:8[] bytes\n"
+ " 0 [+1] UInt length (l)\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ size_in_bytes_field = ir.module[0].type[0].structure.field[2]
+ self.assertEqual("$size_in_bytes", size_in_bytes_field.name.name.text)
+ text_output_attribute = self._find_attribute(size_in_bytes_field, "text_output")
+ self.assertEqual("Skip", text_output_attribute.value.string_constant.text)
- def test_replaces_next(self):
- ir = self._make_ir("struct Foo:\n"
- " 1 [+2] UInt:8[] a\n"
- " $next [+4] UInt b\n"
- " $next [+1] UInt c\n")
- self.assertEqual([], synthetics.desugar(ir))
- offset_of_b = ir.module[0].type[0].structure.field[1].location.start
- self.assertTrue(offset_of_b.HasField("function"))
- self.assertEqual(offset_of_b.function.function, ir_data.FunctionMapping.ADDITION)
- self.assertEqual(offset_of_b.function.args[0].constant.value, "1")
- self.assertEqual(offset_of_b.function.args[1].constant.value, "2")
- offset_of_c = ir.module[0].type[0].structure.field[2].location.start
- self.assertEqual(
- offset_of_c.function.args[0].function.args[0].constant.value, "1")
- self.assertEqual(
- offset_of_c.function.args[0].function.args[1].constant.value, "2")
- self.assertEqual(offset_of_c.function.args[1].constant.value, "4")
+ def test_replaces_next(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 1 [+2] UInt:8[] a\n"
+ " $next [+4] UInt b\n"
+ " $next [+1] UInt c\n"
+ )
+ self.assertEqual([], synthetics.desugar(ir))
+ offset_of_b = ir.module[0].type[0].structure.field[1].location.start
+ self.assertTrue(offset_of_b.HasField("function"))
+ self.assertEqual(
+ offset_of_b.function.function, ir_data.FunctionMapping.ADDITION
+ )
+ self.assertEqual(offset_of_b.function.args[0].constant.value, "1")
+ self.assertEqual(offset_of_b.function.args[1].constant.value, "2")
+ offset_of_c = ir.module[0].type[0].structure.field[2].location.start
+ self.assertEqual(
+ offset_of_c.function.args[0].function.args[0].constant.value, "1"
+ )
+ self.assertEqual(
+ offset_of_c.function.args[0].function.args[1].constant.value, "2"
+ )
+ self.assertEqual(offset_of_c.function.args[1].constant.value, "4")
- def test_next_in_first_field(self):
- ir = self._make_ir("struct Foo:\n"
- " $next [+2] UInt:8[] a\n"
- " $next [+4] UInt b\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[0].location.start.source_location,
- "`$next` may not be used in the first physical field of " +
- "a structure; perhaps you meant `0`?"),
- ]], synthetics.desugar(ir))
+ def test_next_in_first_field(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " $next [+2] UInt:8[] a\n" " $next [+4] UInt b\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[0].location.start.source_location,
+ "`$next` may not be used in the first physical field of "
+ + "a structure; perhaps you meant `0`?",
+ ),
+ ]
+ ],
+ synthetics.desugar(ir),
+ )
- def test_next_in_size(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+2] UInt:8[] a\n"
- " 1 [+$next] UInt b\n")
- struct = ir.module[0].type[0].structure
- self.assertEqual([[
- error.error("m.emb", struct.field[1].location.size.source_location,
- "`$next` may only be used in the start expression of a " +
- "physical field."),
- ]], synthetics.desugar(ir))
+ def test_next_in_size(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+2] UInt:8[] a\n" " 1 [+$next] UInt b\n"
+ )
+ struct = ir.module[0].type[0].structure
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ struct.field[1].location.size.source_location,
+ "`$next` may only be used in the start expression of a "
+ + "physical field.",
+ ),
+ ]
+ ],
+ synthetics.desugar(ir),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/tokenizer.py b/compiler/front_end/tokenizer.py
index 752371f..defee34 100644
--- a/compiler/front_end/tokenizer.py
+++ b/compiler/front_end/tokenizer.py
@@ -36,87 +36,123 @@
def tokenize(text, file_name):
- # TODO(bolms): suppress end-of-line, indent, and dedent tokens between matched
- # delimiters ([], (), and {}).
- """Tokenizes its argument.
+ # TODO(bolms): suppress end-of-line, indent, and dedent tokens between matched
+ # delimiters ([], (), and {}).
+ """Tokenizes its argument.
- Arguments:
- text: The raw text of a .emb file.
- file_name: The name of the file to use in errors.
+ Arguments:
+ text: The raw text of a .emb file.
+ file_name: The name of the file to use in errors.
- Returns:
- A tuple of:
- a list of parser_types.Tokens or None
- a possibly-empty list of errors.
- """
- tokens = []
- indent_stack = [""]
- line_number = 0
- for line in text.splitlines():
- line_number += 1
+ Returns:
+ A tuple of:
+ a list of parser_types.Tokens or None
+ a possibly-empty list of errors.
+ """
+ tokens = []
+ indent_stack = [""]
+ line_number = 0
+ for line in text.splitlines():
+ line_number += 1
- # _tokenize_line splits the actual text into tokens.
- line_tokens, errors = _tokenize_line(line, line_number, file_name)
- if errors:
- return None, errors
+ # _tokenize_line splits the actual text into tokens.
+ line_tokens, errors = _tokenize_line(line, line_number, file_name)
+ if errors:
+ return None, errors
- # Lines with only whitespace and comments are not used for Indent/Dedent
- # calculation, and do not produce end-of-line tokens.
- for token in line_tokens:
- if token.symbol != "Comment":
- break
- else:
- tokens.extend(line_tokens)
- tokens.append(parser_types.Token(
- '"\\n"', "\n", parser_types.make_location(
- (line_number, len(line) + 1), (line_number, len(line) + 1))))
- continue
+ # Lines with only whitespace and comments are not used for Indent/Dedent
+ # calculation, and do not produce end-of-line tokens.
+ for token in line_tokens:
+ if token.symbol != "Comment":
+ break
+ else:
+ tokens.extend(line_tokens)
+ tokens.append(
+ parser_types.Token(
+ '"\\n"',
+ "\n",
+ parser_types.make_location(
+ (line_number, len(line) + 1), (line_number, len(line) + 1)
+ ),
+ )
+ )
+ continue
- # Leading whitespace is whatever .lstrip() removes.
- leading_whitespace = line[0:len(line) - len(line.lstrip())]
- if leading_whitespace == indent_stack[-1]:
- # If the current leading whitespace is equal to the last leading
- # whitespace, do not emit an Indent or Dedent token.
- pass
- elif leading_whitespace.startswith(indent_stack[-1]):
- # If the current leading whitespace is longer than the last leading
- # whitespace, emit an Indent token. For the token text, take the new
- # part of the whitespace.
- tokens.append(
- parser_types.Token(
- "Indent", leading_whitespace[len(indent_stack[-1]):],
- parser_types.make_location(
- (line_number, len(indent_stack[-1]) + 1),
- (line_number, len(leading_whitespace) + 1))))
- indent_stack.append(leading_whitespace)
- else:
- # Otherwise, search for the unclosed indentation level that matches
- # the current indentation level. Emit a Dedent token for each
- # newly-closed indentation level.
- for i in range(len(indent_stack) - 1, -1, -1):
- if leading_whitespace == indent_stack[i]:
- break
+ # Leading whitespace is whatever .lstrip() removes.
+ leading_whitespace = line[0 : len(line) - len(line.lstrip())]
+ if leading_whitespace == indent_stack[-1]:
+ # If the current leading whitespace is equal to the last leading
+ # whitespace, do not emit an Indent or Dedent token.
+ pass
+ elif leading_whitespace.startswith(indent_stack[-1]):
+ # If the current leading whitespace is longer than the last leading
+ # whitespace, emit an Indent token. For the token text, take the new
+ # part of the whitespace.
+ tokens.append(
+ parser_types.Token(
+ "Indent",
+ leading_whitespace[len(indent_stack[-1]) :],
+ parser_types.make_location(
+ (line_number, len(indent_stack[-1]) + 1),
+ (line_number, len(leading_whitespace) + 1),
+ ),
+ )
+ )
+ indent_stack.append(leading_whitespace)
+ else:
+ # Otherwise, search for the unclosed indentation level that matches
+ # the current indentation level. Emit a Dedent token for each
+ # newly-closed indentation level.
+ for i in range(len(indent_stack) - 1, -1, -1):
+ if leading_whitespace == indent_stack[i]:
+ break
+ tokens.append(
+ parser_types.Token(
+ "Dedent",
+ "",
+ parser_types.make_location(
+ (line_number, len(leading_whitespace) + 1),
+ (line_number, len(leading_whitespace) + 1),
+ ),
+ )
+ )
+ del indent_stack[i]
+ else:
+ return None, [
+ [
+ error.error(
+ file_name,
+ parser_types.make_location(
+ (line_number, 1),
+ (line_number, len(leading_whitespace) + 1),
+ ),
+ "Bad indentation",
+ )
+ ]
+ ]
+
+ tokens.extend(line_tokens)
+
+ # Append an end-of-line token (for non-whitespace lines).
tokens.append(
- parser_types.Token("Dedent", "", parser_types.make_location(
- (line_number, len(leading_whitespace) + 1),
- (line_number, len(leading_whitespace) + 1))))
- del indent_stack[i]
- else:
- return None, [[error.error(
- file_name, parser_types.make_location(
- (line_number, 1), (line_number, len(leading_whitespace) + 1)),
- "Bad indentation")]]
+ parser_types.Token(
+ '"\\n"',
+ "\n",
+ parser_types.make_location(
+ (line_number, len(line) + 1), (line_number, len(line) + 1)
+ ),
+ )
+ )
+ for i in range(len(indent_stack) - 1):
+ tokens.append(
+ parser_types.Token(
+ "Dedent",
+ "",
+ parser_types.make_location((line_number + 1, 1), (line_number + 1, 1)),
+ )
+ )
+ return tokens, []
- tokens.extend(line_tokens)
-
- # Append an end-of-line token (for non-whitespace lines).
- tokens.append(parser_types.Token(
- '"\\n"', "\n", parser_types.make_location(
- (line_number, len(line) + 1), (line_number, len(line) + 1))))
- for i in range(len(indent_stack) - 1):
- tokens.append(parser_types.Token("Dedent", "", parser_types.make_location(
- (line_number + 1, 1), (line_number + 1, 1))))
- return tokens, []
# Token patterns used by _tokenize_line.
LITERAL_TOKEN_PATTERNS = (
@@ -125,7 +161,8 @@
"$max $present $upper_bound $lower_bound $next "
"$size_in_bits $size_in_bytes "
"$max_size_in_bits $max_size_in_bytes $min_size_in_bits $min_size_in_bytes "
- "$default struct bits enum external import as if let").split()
+ "$default struct bits enum external import as if let"
+).split()
_T = collections.namedtuple("T", ["regex", "symbol"])
REGEX_TOKEN_PATTERNS = [
# Words starting with variations of "emboss reserved" are reserved for
@@ -168,56 +205,68 @@
def _tokenize_line(line, line_number, file_name):
- """Tokenizes a single line of input.
+ """Tokenizes a single line of input.
- Arguments:
- line: The line of text to tokenize.
- line_number: The line number (used when constructing token objects).
- file_name: The name of a file to use in errors.
+ Arguments:
+ line: The line of text to tokenize.
+ line_number: The line number (used when constructing token objects).
+ file_name: The name of a file to use in errors.
- Returns:
- A tuple of:
- A list of token objects or None.
- A possibly-empty list of errors.
- """
- tokens = []
- offset = 0
- while offset < len(line):
- best_candidate = ""
- best_candidate_symbol = None
- # Find the longest match. Ties go to the first match. This way, keywords
- # ("struct") are matched as themselves, but words that only happen to start
- # with keywords ("structure") are matched as words.
- #
- # There is never a reason to try to match a literal after a regex that
- # could also match that literal, so check literals first.
- for literal in LITERAL_TOKEN_PATTERNS:
- if line[offset:].startswith(literal) and len(literal) > len(
- best_candidate):
- best_candidate = literal
- # For Emboss, the name of a literal token is just the literal in quotes,
- # so that the grammar can read a little more naturally, e.g.:
+ Returns:
+ A tuple of:
+ A list of token objects or None.
+ A possibly-empty list of errors.
+ """
+ tokens = []
+ offset = 0
+ while offset < len(line):
+ best_candidate = ""
+ best_candidate_symbol = None
+ # Find the longest match. Ties go to the first match. This way, keywords
+ # ("struct") are matched as themselves, but words that only happen to start
+ # with keywords ("structure") are matched as words.
#
- # expression -> expression "+" expression
- #
- # instead of
- #
- # expression -> expression Plus expression
- best_candidate_symbol = '"' + literal + '"'
- for pattern in REGEX_TOKEN_PATTERNS:
- match_result = pattern.regex.match(line[offset:])
- if match_result and len(match_result.group(0)) > len(best_candidate):
- best_candidate = match_result.group(0)
- best_candidate_symbol = pattern.symbol
- if not best_candidate:
- return None, [[error.error(
- file_name, parser_types.make_location(
- (line_number, offset + 1), (line_number, offset + 2)),
- "Unrecognized token")]]
- if best_candidate_symbol:
- tokens.append(parser_types.Token(
- best_candidate_symbol, best_candidate, parser_types.make_location(
- (line_number, offset + 1),
- (line_number, offset + len(best_candidate) + 1))))
- offset += len(best_candidate)
- return tokens, None
+ # There is never a reason to try to match a literal after a regex that
+ # could also match that literal, so check literals first.
+ for literal in LITERAL_TOKEN_PATTERNS:
+ if line[offset:].startswith(literal) and len(literal) > len(best_candidate):
+ best_candidate = literal
+ # For Emboss, the name of a literal token is just the literal in quotes,
+ # so that the grammar can read a little more naturally, e.g.:
+ #
+ # expression -> expression "+" expression
+ #
+ # instead of
+ #
+ # expression -> expression Plus expression
+ best_candidate_symbol = '"' + literal + '"'
+ for pattern in REGEX_TOKEN_PATTERNS:
+ match_result = pattern.regex.match(line[offset:])
+ if match_result and len(match_result.group(0)) > len(best_candidate):
+ best_candidate = match_result.group(0)
+ best_candidate_symbol = pattern.symbol
+ if not best_candidate:
+ return None, [
+ [
+ error.error(
+ file_name,
+ parser_types.make_location(
+ (line_number, offset + 1), (line_number, offset + 2)
+ ),
+ "Unrecognized token",
+ )
+ ]
+ ]
+ if best_candidate_symbol:
+ tokens.append(
+ parser_types.Token(
+ best_candidate_symbol,
+ best_candidate,
+ parser_types.make_location(
+ (line_number, offset + 1),
+ (line_number, offset + len(best_candidate) + 1),
+ ),
+ )
+ )
+ offset += len(best_candidate)
+ return tokens, None
diff --git a/compiler/front_end/tokenizer_test.py b/compiler/front_end/tokenizer_test.py
index 91e1b3b..6a301e6 100644
--- a/compiler/front_end/tokenizer_test.py
+++ b/compiler/front_end/tokenizer_test.py
@@ -21,362 +21,479 @@
def _token_symbols(token_list):
- """Given a list of tokens, returns a list of their symbol names."""
- return [token.symbol for token in token_list]
+ """Given a list of tokens, returns a list of their symbol names."""
+ return [token.symbol for token in token_list]
class TokenizerTest(unittest.TestCase):
- """Tests for the tokenizer.tokenize function."""
+ """Tests for the tokenizer.tokenize function."""
- def test_bad_indent_tab_versus_space(self):
- # A bad indent is one that doesn't match a previous unmatched indent.
- tokens, errors = tokenizer.tokenize(" a\n\tb", "file")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("file", parser_types.make_location(
- (2, 1), (2, 2)), "Bad indentation")]], errors)
+ def test_bad_indent_tab_versus_space(self):
+ # A bad indent is one that doesn't match a previous unmatched indent.
+ tokens, errors = tokenizer.tokenize(" a\n\tb", "file")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "file",
+ parser_types.make_location((2, 1), (2, 2)),
+ "Bad indentation",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_bad_indent_tab_versus_eight_spaces(self):
- tokens, errors = tokenizer.tokenize(" a\n\tb", "file")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("file", parser_types.make_location(
- (2, 1), (2, 2)), "Bad indentation")]], errors)
+ def test_bad_indent_tab_versus_eight_spaces(self):
+ tokens, errors = tokenizer.tokenize(" a\n\tb", "file")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "file",
+ parser_types.make_location((2, 1), (2, 2)),
+ "Bad indentation",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_bad_indent_tab_versus_four_spaces(self):
- tokens, errors = tokenizer.tokenize(" a\n\tb", "file")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("file", parser_types.make_location(
- (2, 1), (2, 2)), "Bad indentation")]], errors)
+ def test_bad_indent_tab_versus_four_spaces(self):
+ tokens, errors = tokenizer.tokenize(" a\n\tb", "file")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "file",
+ parser_types.make_location((2, 1), (2, 2)),
+ "Bad indentation",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_bad_indent_two_spaces_versus_one_space(self):
- tokens, errors = tokenizer.tokenize(" a\n b", "file")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("file", parser_types.make_location(
- (2, 1), (2, 2)), "Bad indentation")]], errors)
+ def test_bad_indent_two_spaces_versus_one_space(self):
+ tokens, errors = tokenizer.tokenize(" a\n b", "file")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "file",
+ parser_types.make_location((2, 1), (2, 2)),
+ "Bad indentation",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_bad_indent_matches_closed_indent(self):
- tokens, errors = tokenizer.tokenize(" a\nb\n c\n d", "file")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("file", parser_types.make_location(
- (4, 1), (4, 2)), "Bad indentation")]], errors)
+ def test_bad_indent_matches_closed_indent(self):
+ tokens, errors = tokenizer.tokenize(" a\nb\n c\n d", "file")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "file",
+ parser_types.make_location((4, 1), (4, 2)),
+ "Bad indentation",
+ )
+ ]
+ ],
+ errors,
+ )
- def test_bad_string_after_string_with_escaped_backslash_at_end(self):
- tokens, errors = tokenizer.tokenize(r'"\\""', "name")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("name", parser_types.make_location(
- (1, 5), (1, 6)), "Unrecognized token")]], errors)
+ def test_bad_string_after_string_with_escaped_backslash_at_end(self):
+ tokens, errors = tokenizer.tokenize(r'"\\""', "name")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "name",
+ parser_types.make_location((1, 5), (1, 6)),
+ "Unrecognized token",
+ )
+ ]
+ ],
+ errors,
+ )
def _make_short_token_match_tests():
- """Makes tests for short, simple tokenization cases."""
- eol = '"\\n"'
- cases = {
- "Cam": ["CamelWord", eol],
- "Ca9": ["CamelWord", eol],
- "CanB": ["CamelWord", eol],
- "CanBee": ["CamelWord", eol],
- "CBa": ["CamelWord", eol],
- "cam": ["SnakeWord", eol],
- "ca9": ["SnakeWord", eol],
- "can_b": ["SnakeWord", eol],
- "can_bee": ["SnakeWord", eol],
- "c_ba": ["SnakeWord", eol],
- "cba_": ["SnakeWord", eol],
- "c_b_a_": ["SnakeWord", eol],
- "CAM": ["ShoutyWord", eol],
- "CA9": ["ShoutyWord", eol],
- "CAN_B": ["ShoutyWord", eol],
- "CAN_BEE": ["ShoutyWord", eol],
- "C_BA": ["ShoutyWord", eol],
- "C": ["BadWord", eol],
- "C1": ["BadWord", eol],
- "c": ["SnakeWord", eol],
- "$": ["BadWord", eol],
- "_": ["BadWord", eol],
- "_a": ["BadWord", eol],
- "_A": ["BadWord", eol],
- "Cb_A": ["BadWord", eol],
- "aCb": ["BadWord", eol],
- "a b": ["SnakeWord", "SnakeWord", eol],
- "a\tb": ["SnakeWord", "SnakeWord", eol],
- "a \t b ": ["SnakeWord", "SnakeWord", eol],
- " \t ": [eol],
- "a #b": ["SnakeWord", "Comment", eol],
- "a#": ["SnakeWord", "Comment", eol],
- "# b": ["Comment", eol],
- " # b": ["Comment", eol],
- " #": ["Comment", eol],
- "": [],
- "\n": [eol],
- "\na": [eol, "SnakeWord", eol],
- "a--example": ["SnakeWord", "BadDocumentation", eol],
- "a ---- example": ["SnakeWord", "BadDocumentation", eol],
- "a --- example": ["SnakeWord", "BadDocumentation", eol],
- "a-- example": ["SnakeWord", "Documentation", eol],
- "a -- -- example": ["SnakeWord", "Documentation", eol],
- "a -- - example": ["SnakeWord", "Documentation", eol],
- "--": ["Documentation", eol],
- "-- ": ["Documentation", eol],
- "-- ": ["Documentation", eol],
- "$default": ['"$default"', eol],
- "$defaultx": ["BadWord", eol],
- "$def": ["BadWord", eol],
- "x$default": ["BadWord", eol],
- "9$default": ["BadWord", eol],
- "struct": ['"struct"', eol],
- "external": ['"external"', eol],
- "bits": ['"bits"', eol],
- "enum": ['"enum"', eol],
- "as": ['"as"', eol],
- "import": ['"import"', eol],
- "true": ["BooleanConstant", eol],
- "false": ["BooleanConstant", eol],
- "truex": ["SnakeWord", eol],
- "falsex": ["SnakeWord", eol],
- "structx": ["SnakeWord", eol],
- "bitsx": ["SnakeWord", eol],
- "enumx": ["SnakeWord", eol],
- "0b": ["BadNumber", eol],
- "0x": ["BadNumber", eol],
- "0b011101": ["Number", eol],
- "0b0": ["Number", eol],
- "0b0111_1111_0000": ["Number", eol],
- "0b00_000_00": ["BadNumber", eol],
- "0b0_0_0": ["BadNumber", eol],
- "0b0111012": ["BadNumber", eol],
- "0b011101x": ["BadWord", eol],
- "0b011101b": ["BadNumber", eol],
- "0B0": ["BadNumber", eol],
- "0X0": ["BadNumber", eol],
- "0b_": ["BadNumber", eol],
- "0x_": ["BadNumber", eol],
- "0b__": ["BadNumber", eol],
- "0x__": ["BadNumber", eol],
- "0b_0000": ["Number", eol],
- "0b0000_": ["BadNumber", eol],
- "0b00_____00": ["BadNumber", eol],
- "0x00_000_00": ["BadNumber", eol],
- "0x0_0_0": ["BadNumber", eol],
- "0b____0____": ["BadNumber", eol],
- "0b00000000000000000000": ["Number", eol],
- "0b_00000000": ["Number", eol],
- "0b0000_0000_0000": ["Number", eol],
- "0b000_0000_0000": ["Number", eol],
- "0b00_0000_0000": ["Number", eol],
- "0b0_0000_0000": ["Number", eol],
- "0b_0000_0000_0000": ["Number", eol],
- "0b_000_0000_0000": ["Number", eol],
- "0b_00_0000_0000": ["Number", eol],
- "0b_0_0000_0000": ["Number", eol],
- "0b00000000_00000000_00000000": ["Number", eol],
- "0b0000000_00000000_00000000": ["Number", eol],
- "0b000000_00000000_00000000": ["Number", eol],
- "0b00000_00000000_00000000": ["Number", eol],
- "0b0000_00000000_00000000": ["Number", eol],
- "0b000_00000000_00000000": ["Number", eol],
- "0b00_00000000_00000000": ["Number", eol],
- "0b0_00000000_00000000": ["Number", eol],
- "0b_00000000_00000000_00000000": ["Number", eol],
- "0b_0000000_00000000_00000000": ["Number", eol],
- "0b_000000_00000000_00000000": ["Number", eol],
- "0b_00000_00000000_00000000": ["Number", eol],
- "0b_0000_00000000_00000000": ["Number", eol],
- "0b_000_00000000_00000000": ["Number", eol],
- "0b_00_00000000_00000000": ["Number", eol],
- "0b_0_00000000_00000000": ["Number", eol],
- "0x0": ["Number", eol],
- "0x00000000000000000000": ["Number", eol],
- "0x_0000": ["Number", eol],
- "0x_00000000": ["Number", eol],
- "0x0000_0000_0000": ["Number", eol],
- "0x000_0000_0000": ["Number", eol],
- "0x00_0000_0000": ["Number", eol],
- "0x0_0000_0000": ["Number", eol],
- "0x_0000_0000_0000": ["Number", eol],
- "0x_000_0000_0000": ["Number", eol],
- "0x_00_0000_0000": ["Number", eol],
- "0x_0_0000_0000": ["Number", eol],
- "0x00000000_00000000_00000000": ["Number", eol],
- "0x0000000_00000000_00000000": ["Number", eol],
- "0x000000_00000000_00000000": ["Number", eol],
- "0x00000_00000000_00000000": ["Number", eol],
- "0x0000_00000000_00000000": ["Number", eol],
- "0x000_00000000_00000000": ["Number", eol],
- "0x00_00000000_00000000": ["Number", eol],
- "0x0_00000000_00000000": ["Number", eol],
- "0x_00000000_00000000_00000000": ["Number", eol],
- "0x_0000000_00000000_00000000": ["Number", eol],
- "0x_000000_00000000_00000000": ["Number", eol],
- "0x_00000_00000000_00000000": ["Number", eol],
- "0x_0000_00000000_00000000": ["Number", eol],
- "0x_000_00000000_00000000": ["Number", eol],
- "0x_00_00000000_00000000": ["Number", eol],
- "0x_0_00000000_00000000": ["Number", eol],
- "0x__00000000_00000000": ["BadNumber", eol],
- "0x00000000_00000000_0000": ["BadNumber", eol],
- "0x00000000_0000_0000": ["BadNumber", eol],
- "0x_00000000000000000000": ["BadNumber", eol],
- "0b_00000000000000000000": ["BadNumber", eol],
- "0b00000000_00000000_0000": ["BadNumber", eol],
- "0b00000000_0000_0000": ["BadNumber", eol],
- "0x0000_": ["BadNumber", eol],
- "0x00_____00": ["BadNumber", eol],
- "0x____0____": ["BadNumber", eol],
- "EmbossReserved": ["BadWord", eol],
- "EmbossReservedA": ["BadWord", eol],
- "EmbossReserved_": ["BadWord", eol],
- "EMBOSS_RESERVED": ["BadWord", eol],
- "EMBOSS_RESERVED_": ["BadWord", eol],
- "EMBOSS_RESERVEDA": ["BadWord", eol],
- "emboss_reserved": ["BadWord", eol],
- "emboss_reserved_": ["BadWord", eol],
- "emboss_reserveda": ["BadWord", eol],
- "0x0123456789abcdefABCDEF": ["Number", eol],
- "0": ["Number", eol],
- "1": ["Number", eol],
- "1a": ["BadNumber", eol],
- "1g": ["BadWord", eol],
- "1234567890": ["Number", eol],
- "1_234_567_890": ["Number", eol],
- "234_567_890": ["Number", eol],
- "34_567_890": ["Number", eol],
- "4_567_890": ["Number", eol],
- "1_2_3_4_5_6_7_8_9_0": ["BadNumber", eol],
- "1234567890_": ["BadNumber", eol],
- "1__234567890": ["BadNumber", eol],
- "_1234567890": ["BadWord", eol],
- "[]": ['"["', '"]"', eol],
- "()": ['"("', '")"', eol],
- "..": ['"."', '"."', eol],
- "...": ['"."', '"."', '"."', eol],
- "....": ['"."', '"."', '"."', '"."', eol],
- '"abc"': ["String", eol],
- '""': ["String", eol],
- r'"\\"': ["String", eol],
- r'"\""': ["String", eol],
- r'"\n"': ["String", eol],
- r'"\\n"': ["String", eol],
- r'"\\xyz"': ["String", eol],
- r'"\\\\"': ["String", eol],
- }
- for c in ("[ ] ( ) ? : = + - * . == != < <= > >= && || , $max $present "
- "$upper_bound $lower_bound $size_in_bits $size_in_bytes "
- "$max_size_in_bits $max_size_in_bytes $min_size_in_bits "
- "$min_size_in_bytes "
- "$default struct bits enum external import as if let").split():
- cases[c] = ['"' + c + '"', eol]
+ """Makes tests for short, simple tokenization cases."""
+ eol = '"\\n"'
+ cases = {
+ "Cam": ["CamelWord", eol],
+ "Ca9": ["CamelWord", eol],
+ "CanB": ["CamelWord", eol],
+ "CanBee": ["CamelWord", eol],
+ "CBa": ["CamelWord", eol],
+ "cam": ["SnakeWord", eol],
+ "ca9": ["SnakeWord", eol],
+ "can_b": ["SnakeWord", eol],
+ "can_bee": ["SnakeWord", eol],
+ "c_ba": ["SnakeWord", eol],
+ "cba_": ["SnakeWord", eol],
+ "c_b_a_": ["SnakeWord", eol],
+ "CAM": ["ShoutyWord", eol],
+ "CA9": ["ShoutyWord", eol],
+ "CAN_B": ["ShoutyWord", eol],
+ "CAN_BEE": ["ShoutyWord", eol],
+ "C_BA": ["ShoutyWord", eol],
+ "C": ["BadWord", eol],
+ "C1": ["BadWord", eol],
+ "c": ["SnakeWord", eol],
+ "$": ["BadWord", eol],
+ "_": ["BadWord", eol],
+ "_a": ["BadWord", eol],
+ "_A": ["BadWord", eol],
+ "Cb_A": ["BadWord", eol],
+ "aCb": ["BadWord", eol],
+ "a b": ["SnakeWord", "SnakeWord", eol],
+ "a\tb": ["SnakeWord", "SnakeWord", eol],
+ "a \t b ": ["SnakeWord", "SnakeWord", eol],
+ " \t ": [eol],
+ "a #b": ["SnakeWord", "Comment", eol],
+ "a#": ["SnakeWord", "Comment", eol],
+ "# b": ["Comment", eol],
+ " # b": ["Comment", eol],
+ " #": ["Comment", eol],
+ "": [],
+ "\n": [eol],
+ "\na": [eol, "SnakeWord", eol],
+ "a--example": ["SnakeWord", "BadDocumentation", eol],
+ "a ---- example": ["SnakeWord", "BadDocumentation", eol],
+ "a --- example": ["SnakeWord", "BadDocumentation", eol],
+ "a-- example": ["SnakeWord", "Documentation", eol],
+ "a -- -- example": ["SnakeWord", "Documentation", eol],
+ "a -- - example": ["SnakeWord", "Documentation", eol],
+ "--": ["Documentation", eol],
+ "-- ": ["Documentation", eol],
+ "-- ": ["Documentation", eol],
+ "$default": ['"$default"', eol],
+ "$defaultx": ["BadWord", eol],
+ "$def": ["BadWord", eol],
+ "x$default": ["BadWord", eol],
+ "9$default": ["BadWord", eol],
+ "struct": ['"struct"', eol],
+ "external": ['"external"', eol],
+ "bits": ['"bits"', eol],
+ "enum": ['"enum"', eol],
+ "as": ['"as"', eol],
+ "import": ['"import"', eol],
+ "true": ["BooleanConstant", eol],
+ "false": ["BooleanConstant", eol],
+ "truex": ["SnakeWord", eol],
+ "falsex": ["SnakeWord", eol],
+ "structx": ["SnakeWord", eol],
+ "bitsx": ["SnakeWord", eol],
+ "enumx": ["SnakeWord", eol],
+ "0b": ["BadNumber", eol],
+ "0x": ["BadNumber", eol],
+ "0b011101": ["Number", eol],
+ "0b0": ["Number", eol],
+ "0b0111_1111_0000": ["Number", eol],
+ "0b00_000_00": ["BadNumber", eol],
+ "0b0_0_0": ["BadNumber", eol],
+ "0b0111012": ["BadNumber", eol],
+ "0b011101x": ["BadWord", eol],
+ "0b011101b": ["BadNumber", eol],
+ "0B0": ["BadNumber", eol],
+ "0X0": ["BadNumber", eol],
+ "0b_": ["BadNumber", eol],
+ "0x_": ["BadNumber", eol],
+ "0b__": ["BadNumber", eol],
+ "0x__": ["BadNumber", eol],
+ "0b_0000": ["Number", eol],
+ "0b0000_": ["BadNumber", eol],
+ "0b00_____00": ["BadNumber", eol],
+ "0x00_000_00": ["BadNumber", eol],
+ "0x0_0_0": ["BadNumber", eol],
+ "0b____0____": ["BadNumber", eol],
+ "0b00000000000000000000": ["Number", eol],
+ "0b_00000000": ["Number", eol],
+ "0b0000_0000_0000": ["Number", eol],
+ "0b000_0000_0000": ["Number", eol],
+ "0b00_0000_0000": ["Number", eol],
+ "0b0_0000_0000": ["Number", eol],
+ "0b_0000_0000_0000": ["Number", eol],
+ "0b_000_0000_0000": ["Number", eol],
+ "0b_00_0000_0000": ["Number", eol],
+ "0b_0_0000_0000": ["Number", eol],
+ "0b00000000_00000000_00000000": ["Number", eol],
+ "0b0000000_00000000_00000000": ["Number", eol],
+ "0b000000_00000000_00000000": ["Number", eol],
+ "0b00000_00000000_00000000": ["Number", eol],
+ "0b0000_00000000_00000000": ["Number", eol],
+ "0b000_00000000_00000000": ["Number", eol],
+ "0b00_00000000_00000000": ["Number", eol],
+ "0b0_00000000_00000000": ["Number", eol],
+ "0b_00000000_00000000_00000000": ["Number", eol],
+ "0b_0000000_00000000_00000000": ["Number", eol],
+ "0b_000000_00000000_00000000": ["Number", eol],
+ "0b_00000_00000000_00000000": ["Number", eol],
+ "0b_0000_00000000_00000000": ["Number", eol],
+ "0b_000_00000000_00000000": ["Number", eol],
+ "0b_00_00000000_00000000": ["Number", eol],
+ "0b_0_00000000_00000000": ["Number", eol],
+ "0x0": ["Number", eol],
+ "0x00000000000000000000": ["Number", eol],
+ "0x_0000": ["Number", eol],
+ "0x_00000000": ["Number", eol],
+ "0x0000_0000_0000": ["Number", eol],
+ "0x000_0000_0000": ["Number", eol],
+ "0x00_0000_0000": ["Number", eol],
+ "0x0_0000_0000": ["Number", eol],
+ "0x_0000_0000_0000": ["Number", eol],
+ "0x_000_0000_0000": ["Number", eol],
+ "0x_00_0000_0000": ["Number", eol],
+ "0x_0_0000_0000": ["Number", eol],
+ "0x00000000_00000000_00000000": ["Number", eol],
+ "0x0000000_00000000_00000000": ["Number", eol],
+ "0x000000_00000000_00000000": ["Number", eol],
+ "0x00000_00000000_00000000": ["Number", eol],
+ "0x0000_00000000_00000000": ["Number", eol],
+ "0x000_00000000_00000000": ["Number", eol],
+ "0x00_00000000_00000000": ["Number", eol],
+ "0x0_00000000_00000000": ["Number", eol],
+ "0x_00000000_00000000_00000000": ["Number", eol],
+ "0x_0000000_00000000_00000000": ["Number", eol],
+ "0x_000000_00000000_00000000": ["Number", eol],
+ "0x_00000_00000000_00000000": ["Number", eol],
+ "0x_0000_00000000_00000000": ["Number", eol],
+ "0x_000_00000000_00000000": ["Number", eol],
+ "0x_00_00000000_00000000": ["Number", eol],
+ "0x_0_00000000_00000000": ["Number", eol],
+ "0x__00000000_00000000": ["BadNumber", eol],
+ "0x00000000_00000000_0000": ["BadNumber", eol],
+ "0x00000000_0000_0000": ["BadNumber", eol],
+ "0x_00000000000000000000": ["BadNumber", eol],
+ "0b_00000000000000000000": ["BadNumber", eol],
+ "0b00000000_00000000_0000": ["BadNumber", eol],
+ "0b00000000_0000_0000": ["BadNumber", eol],
+ "0x0000_": ["BadNumber", eol],
+ "0x00_____00": ["BadNumber", eol],
+ "0x____0____": ["BadNumber", eol],
+ "EmbossReserved": ["BadWord", eol],
+ "EmbossReservedA": ["BadWord", eol],
+ "EmbossReserved_": ["BadWord", eol],
+ "EMBOSS_RESERVED": ["BadWord", eol],
+ "EMBOSS_RESERVED_": ["BadWord", eol],
+ "EMBOSS_RESERVEDA": ["BadWord", eol],
+ "emboss_reserved": ["BadWord", eol],
+ "emboss_reserved_": ["BadWord", eol],
+ "emboss_reserveda": ["BadWord", eol],
+ "0x0123456789abcdefABCDEF": ["Number", eol],
+ "0": ["Number", eol],
+ "1": ["Number", eol],
+ "1a": ["BadNumber", eol],
+ "1g": ["BadWord", eol],
+ "1234567890": ["Number", eol],
+ "1_234_567_890": ["Number", eol],
+ "234_567_890": ["Number", eol],
+ "34_567_890": ["Number", eol],
+ "4_567_890": ["Number", eol],
+ "1_2_3_4_5_6_7_8_9_0": ["BadNumber", eol],
+ "1234567890_": ["BadNumber", eol],
+ "1__234567890": ["BadNumber", eol],
+ "_1234567890": ["BadWord", eol],
+ "[]": ['"["', '"]"', eol],
+ "()": ['"("', '")"', eol],
+ "..": ['"."', '"."', eol],
+ "...": ['"."', '"."', '"."', eol],
+ "....": ['"."', '"."', '"."', '"."', eol],
+ '"abc"': ["String", eol],
+ '""': ["String", eol],
+ r'"\\"': ["String", eol],
+ r'"\""': ["String", eol],
+ r'"\n"': ["String", eol],
+ r'"\\n"': ["String", eol],
+ r'"\\xyz"': ["String", eol],
+ r'"\\\\"': ["String", eol],
+ }
+ for c in (
+ "[ ] ( ) ? : = + - * . == != < <= > >= && || , $max $present "
+ "$upper_bound $lower_bound $size_in_bits $size_in_bytes "
+ "$max_size_in_bits $max_size_in_bytes $min_size_in_bits "
+ "$min_size_in_bytes "
+ "$default struct bits enum external import as if let"
+ ).split():
+ cases[c] = ['"' + c + '"', eol]
- def make_test_case(case):
+ def make_test_case(case):
- def test_case(self):
- tokens, errors = tokenizer.tokenize(case, "name")
- symbols = _token_symbols(tokens)
- self.assertFalse(errors)
- self.assertEqual(symbols, cases[case])
+ def test_case(self):
+ tokens, errors = tokenizer.tokenize(case, "name")
+ symbols = _token_symbols(tokens)
+ self.assertFalse(errors)
+ self.assertEqual(symbols, cases[case])
- return test_case
+ return test_case
- for c in cases:
- setattr(TokenizerTest, "testShortTokenMatch{!r}".format(c),
- make_test_case(c))
+ for c in cases:
+ setattr(TokenizerTest, "testShortTokenMatch{!r}".format(c), make_test_case(c))
def _make_bad_char_tests():
- """Makes tests that an error is returned for bad characters."""
+ """Makes tests that an error is returned for bad characters."""
- def make_test_case(case):
+ def make_test_case(case):
- def test_case(self):
- tokens, errors = tokenizer.tokenize(case, "name")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("name", parser_types.make_location(
- (1, 1), (1, 2)), "Unrecognized token")]], errors)
+ def test_case(self):
+ tokens, errors = tokenizer.tokenize(case, "name")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "name",
+ parser_types.make_location((1, 1), (1, 2)),
+ "Unrecognized token",
+ )
+ ]
+ ],
+ errors,
+ )
- return test_case
+ return test_case
- for c in "~`!@%^&\\|;'\"/{}":
- setattr(TokenizerTest, "testBadChar{!r}".format(c), make_test_case(c))
+ for c in "~`!@%^&\\|;'\"/{}":
+ setattr(TokenizerTest, "testBadChar{!r}".format(c), make_test_case(c))
def _make_bad_string_tests():
- """Makes tests that an error is returned for bad strings."""
- bad_strings = (r'"\"', '"\\\n"', r'"\\\"', r'"', r'"\q"', r'"\\\q"')
+ """Makes tests that an error is returned for bad strings."""
+ bad_strings = (r'"\"', '"\\\n"', r'"\\\"', r'"', r'"\q"', r'"\\\q"')
- def make_test_case(string):
+ def make_test_case(string):
- def test_case(self):
- tokens, errors = tokenizer.tokenize(string, "name")
- self.assertFalse(tokens)
- self.assertEqual([[error.error("name", parser_types.make_location(
- (1, 1), (1, 2)), "Unrecognized token")]], errors)
+ def test_case(self):
+ tokens, errors = tokenizer.tokenize(string, "name")
+ self.assertFalse(tokens)
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "name",
+ parser_types.make_location((1, 1), (1, 2)),
+ "Unrecognized token",
+ )
+ ]
+ ],
+ errors,
+ )
- return test_case
+ return test_case
- for s in bad_strings:
- setattr(TokenizerTest, "testBadString{!r}".format(s), make_test_case(s))
+ for s in bad_strings:
+ setattr(TokenizerTest, "testBadString{!r}".format(s), make_test_case(s))
def _make_multiline_tests():
- """Makes tests for indent/dedent insertion and eol insertion."""
+ """Makes tests for indent/dedent insertion and eol insertion."""
- c = "Comment"
- eol = '"\\n"'
- sw = "SnakeWord"
- ind = "Indent"
- ded = "Dedent"
- cases = {
- "a\nb\n": [sw, eol, sw, eol],
- "a\n\nb\n": [sw, eol, eol, sw, eol],
- "a\n#foo\nb\n": [sw, eol, c, eol, sw, eol],
- "a\n #foo\nb\n": [sw, eol, c, eol, sw, eol],
- "a\n b\n": [sw, eol, ind, sw, eol, ded],
- "a\n b\n\n": [sw, eol, ind, sw, eol, eol, ded],
- "a\n b\n c\n": [sw, eol, ind, sw, eol, ind, sw, eol, ded, ded],
- "a\n b\n c\n": [sw, eol, ind, sw, eol, sw, eol, ded],
- "a\n b\n\n c\n": [sw, eol, ind, sw, eol, eol, sw, eol, ded],
- "a\n b\n #\n c\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded],
- "a\n\tb\n #\n\tc\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded],
- " a\n b\n c\n d\n": [ind, sw, eol, ind, sw, eol, ind, sw, eol, ded,
- ded, sw, eol, ded],
- }
+ c = "Comment"
+ eol = '"\\n"'
+ sw = "SnakeWord"
+ ind = "Indent"
+ ded = "Dedent"
+ cases = {
+ "a\nb\n": [sw, eol, sw, eol],
+ "a\n\nb\n": [sw, eol, eol, sw, eol],
+ "a\n#foo\nb\n": [sw, eol, c, eol, sw, eol],
+ "a\n #foo\nb\n": [sw, eol, c, eol, sw, eol],
+ "a\n b\n": [sw, eol, ind, sw, eol, ded],
+ "a\n b\n\n": [sw, eol, ind, sw, eol, eol, ded],
+ "a\n b\n c\n": [sw, eol, ind, sw, eol, ind, sw, eol, ded, ded],
+ "a\n b\n c\n": [sw, eol, ind, sw, eol, sw, eol, ded],
+ "a\n b\n\n c\n": [sw, eol, ind, sw, eol, eol, sw, eol, ded],
+ "a\n b\n #\n c\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded],
+ "a\n\tb\n #\n\tc\n": [sw, eol, ind, sw, eol, c, eol, sw, eol, ded],
+ " a\n b\n c\n d\n": [
+ ind,
+ sw,
+ eol,
+ ind,
+ sw,
+ eol,
+ ind,
+ sw,
+ eol,
+ ded,
+ ded,
+ sw,
+ eol,
+ ded,
+ ],
+ }
- def make_test_case(case):
+ def make_test_case(case):
- def test_case(self):
- tokens, errors = tokenizer.tokenize(case, "file")
- self.assertFalse(errors)
- self.assertEqual(_token_symbols(tokens), cases[case])
+ def test_case(self):
+ tokens, errors = tokenizer.tokenize(case, "file")
+ self.assertFalse(errors)
+ self.assertEqual(_token_symbols(tokens), cases[case])
- return test_case
+ return test_case
- for c in cases:
- setattr(TokenizerTest, "testMultiline{!r}".format(c), make_test_case(c))
+ for c in cases:
+ setattr(TokenizerTest, "testMultiline{!r}".format(c), make_test_case(c))
def _make_offset_tests():
- """Makes tests that the tokenizer fills in correct source locations."""
- cases = {
- "a+": ["1:1-1:2", "1:2-1:3", "1:3-1:3"],
- "a + ": ["1:1-1:2", "1:5-1:6", "1:9-1:9"],
- "a\n\nb": ["1:1-1:2", "1:2-1:2", "2:1-2:1", "3:1-3:2", "3:2-3:2"],
- "a\n b": ["1:1-1:2", "1:2-1:2", "2:1-2:3", "2:3-2:4", "2:4-2:4",
- "3:1-3:1"],
- "a\n b\nc": ["1:1-1:2", "1:2-1:2", "2:1-2:3", "2:3-2:4", "2:4-2:4",
- "3:1-3:1", "3:1-3:2", "3:2-3:2"],
- "a\n b\n c": ["1:1-1:2", "1:2-1:2", "2:1-2:2", "2:2-2:3", "2:3-2:3",
- "3:2-3:3", "3:3-3:4", "3:4-3:4", "4:1-4:1", "4:1-4:1"],
- }
+ """Makes tests that the tokenizer fills in correct source locations."""
+ cases = {
+ "a+": ["1:1-1:2", "1:2-1:3", "1:3-1:3"],
+ "a + ": ["1:1-1:2", "1:5-1:6", "1:9-1:9"],
+ "a\n\nb": ["1:1-1:2", "1:2-1:2", "2:1-2:1", "3:1-3:2", "3:2-3:2"],
+ "a\n b": ["1:1-1:2", "1:2-1:2", "2:1-2:3", "2:3-2:4", "2:4-2:4", "3:1-3:1"],
+ "a\n b\nc": [
+ "1:1-1:2",
+ "1:2-1:2",
+ "2:1-2:3",
+ "2:3-2:4",
+ "2:4-2:4",
+ "3:1-3:1",
+ "3:1-3:2",
+ "3:2-3:2",
+ ],
+ "a\n b\n c": [
+ "1:1-1:2",
+ "1:2-1:2",
+ "2:1-2:2",
+ "2:2-2:3",
+ "2:3-2:3",
+ "3:2-3:3",
+ "3:3-3:4",
+ "3:4-3:4",
+ "4:1-4:1",
+ "4:1-4:1",
+ ],
+ }
- def make_test_case(case):
+ def make_test_case(case):
- def test_case(self):
- self.assertEqual([parser_types.format_location(l.source_location)
- for l in tokenizer.tokenize(case, "file")[0]],
- cases[case])
+ def test_case(self):
+ self.assertEqual(
+ [
+ parser_types.format_location(l.source_location)
+ for l in tokenizer.tokenize(case, "file")[0]
+ ],
+ cases[case],
+ )
- return test_case
+ return test_case
- for c in cases:
- setattr(TokenizerTest, "testOffset{!r}".format(c), make_test_case(c))
+ for c in cases:
+ setattr(TokenizerTest, "testOffset{!r}".format(c), make_test_case(c))
+
_make_short_token_match_tests()
_make_bad_char_tests()
@@ -385,4 +502,4 @@
_make_offset_tests()
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index 9ad4aee..f562cc8 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -23,461 +23,633 @@
def _type_check_expression(expression, source_file_name, ir, errors):
- """Checks and annotates the type of an expression and all subexpressions."""
- if ir_data_utils.reader(expression).type.WhichOneof("type"):
- # This expression has already been type checked.
- return
- expression_variety = expression.WhichOneof("expression")
- if expression_variety == "constant":
- _type_check_integer_constant(expression)
- elif expression_variety == "constant_reference":
- _type_check_constant_reference(expression, source_file_name, ir, errors)
- elif expression_variety == "function":
- _type_check_operation(expression, source_file_name, ir, errors)
- elif expression_variety == "field_reference":
- _type_check_local_reference(expression, ir, errors)
- elif expression_variety == "boolean_constant":
- _type_check_boolean_constant(expression)
- elif expression_variety == "builtin_reference":
- _type_check_builtin_reference(expression)
- else:
- assert False, "Unknown expression variety {!r}".format(expression_variety)
+ """Checks and annotates the type of an expression and all subexpressions."""
+ if ir_data_utils.reader(expression).type.WhichOneof("type"):
+ # This expression has already been type checked.
+ return
+ expression_variety = expression.WhichOneof("expression")
+ if expression_variety == "constant":
+ _type_check_integer_constant(expression)
+ elif expression_variety == "constant_reference":
+ _type_check_constant_reference(expression, source_file_name, ir, errors)
+ elif expression_variety == "function":
+ _type_check_operation(expression, source_file_name, ir, errors)
+ elif expression_variety == "field_reference":
+ _type_check_local_reference(expression, ir, errors)
+ elif expression_variety == "boolean_constant":
+ _type_check_boolean_constant(expression)
+ elif expression_variety == "builtin_reference":
+ _type_check_builtin_reference(expression)
+ else:
+ assert False, "Unknown expression variety {!r}".format(expression_variety)
def _annotate_as_integer(expression):
- ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
+ ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
def _annotate_as_boolean(expression):
- ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
+ ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
-def _type_check(expression, source_file_name, errors, type_oneof, type_name,
- expression_name):
- if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof:
- errors.append([
- error.error(source_file_name, expression.source_location,
- "{} must be {}.".format(expression_name, type_name))
- ])
+def _type_check(
+ expression, source_file_name, errors, type_oneof, type_name, expression_name
+):
+ if ir_data_utils.reader(expression).type.WhichOneof("type") != type_oneof:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
+ "{} must be {}.".format(expression_name, type_name),
+ )
+ ]
+ )
def _type_check_integer(expression, source_file_name, errors, expression_name):
- _type_check(expression, source_file_name, errors, "integer",
- "an integer", expression_name)
+ _type_check(
+ expression, source_file_name, errors, "integer", "an integer", expression_name
+ )
def _type_check_boolean(expression, source_file_name, errors, expression_name):
- _type_check(expression, source_file_name, errors, "boolean", "a boolean",
- expression_name)
+ _type_check(
+ expression, source_file_name, errors, "boolean", "a boolean", expression_name
+ )
-def _kind_check_field_reference(expression, source_file_name, errors,
- expression_name):
- if expression.WhichOneof("expression") != "field_reference":
- errors.append([
- error.error(source_file_name, expression.source_location,
- "{} must be a field.".format(expression_name))
- ])
+def _kind_check_field_reference(expression, source_file_name, errors, expression_name):
+ if expression.WhichOneof("expression") != "field_reference":
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
+ "{} must be a field.".format(expression_name),
+ )
+ ]
+ )
def _type_check_integer_constant(expression):
- _annotate_as_integer(expression)
+ _annotate_as_integer(expression)
def _type_check_constant_reference(expression, source_file_name, ir, errors):
- """Annotates the type of a constant reference."""
- referred_name = expression.constant_reference.canonical_name
- referred_object = ir_util.find_object(referred_name, ir)
- if isinstance(referred_object, ir_data.EnumValue):
- ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(expression.constant_reference)
- del expression.type.enumeration.name.canonical_name.object_path[-1]
- elif isinstance(referred_object, ir_data.Field):
- if not ir_util.field_is_virtual(referred_object):
- errors.append([
- error.error(source_file_name, expression.source_location,
- "Static references to physical fields are not allowed."),
- error.note(referred_name.module_file, referred_object.source_location,
- "{} is a physical field.".format(
- referred_name.object_path[-1])),
- ])
- return
- _type_check_expression(referred_object.read_transform,
- referred_name.module_file, ir, errors)
- ir_data_utils.builder(expression).type.CopyFrom(referred_object.read_transform.type)
- else:
- assert False, "Unexpected constant reference type."
+ """Annotates the type of a constant reference."""
+ referred_name = expression.constant_reference.canonical_name
+ referred_object = ir_util.find_object(referred_name, ir)
+ if isinstance(referred_object, ir_data.EnumValue):
+ ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(
+ expression.constant_reference
+ )
+ del expression.type.enumeration.name.canonical_name.object_path[-1]
+ elif isinstance(referred_object, ir_data.Field):
+ if not ir_util.field_is_virtual(referred_object):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
+ "Static references to physical fields are not allowed.",
+ ),
+ error.note(
+ referred_name.module_file,
+ referred_object.source_location,
+ "{} is a physical field.".format(referred_name.object_path[-1]),
+ ),
+ ]
+ )
+ return
+ _type_check_expression(
+ referred_object.read_transform, referred_name.module_file, ir, errors
+ )
+ ir_data_utils.builder(expression).type.CopyFrom(
+ referred_object.read_transform.type
+ )
+ else:
+ assert False, "Unexpected constant reference type."
def _type_check_operation(expression, source_file_name, ir, errors):
- for arg in expression.function.args:
- _type_check_expression(arg, source_file_name, ir, errors)
- function = expression.function.function
- if function in (ir_data.FunctionMapping.EQUALITY, ir_data.FunctionMapping.INEQUALITY,
- ir_data.FunctionMapping.LESS, ir_data.FunctionMapping.LESS_OR_EQUAL,
- ir_data.FunctionMapping.GREATER, ir_data.FunctionMapping.GREATER_OR_EQUAL):
- _type_check_comparison_operator(expression, source_file_name, errors)
- elif function == ir_data.FunctionMapping.CHOICE:
- _type_check_choice_operator(expression, source_file_name, errors)
- else:
- _type_check_monomorphic_operator(expression, source_file_name, errors)
+ for arg in expression.function.args:
+ _type_check_expression(arg, source_file_name, ir, errors)
+ function = expression.function.function
+ if function in (
+ ir_data.FunctionMapping.EQUALITY,
+ ir_data.FunctionMapping.INEQUALITY,
+ ir_data.FunctionMapping.LESS,
+ ir_data.FunctionMapping.LESS_OR_EQUAL,
+ ir_data.FunctionMapping.GREATER,
+ ir_data.FunctionMapping.GREATER_OR_EQUAL,
+ ):
+ _type_check_comparison_operator(expression, source_file_name, errors)
+ elif function == ir_data.FunctionMapping.CHOICE:
+ _type_check_choice_operator(expression, source_file_name, errors)
+ else:
+ _type_check_monomorphic_operator(expression, source_file_name, errors)
def _type_check_monomorphic_operator(expression, source_file_name, errors):
- """Type checks an operator that accepts only one set of argument types."""
- args = expression.function.args
- int_args = _type_check_integer
- bool_args = _type_check_boolean
- field_args = _kind_check_field_reference
- int_result = _annotate_as_integer
- bool_result = _annotate_as_boolean
- binary = ("Left argument", "Right argument")
- n_ary = ("Argument {}".format(n) for n in range(len(args)))
- functions = {
- ir_data.FunctionMapping.ADDITION: (int_result, int_args, binary, 2, 2,
- "operator"),
- ir_data.FunctionMapping.SUBTRACTION: (int_result, int_args, binary, 2, 2,
- "operator"),
- ir_data.FunctionMapping.MULTIPLICATION: (int_result, int_args, binary, 2, 2,
- "operator"),
- ir_data.FunctionMapping.AND: (bool_result, bool_args, binary, 2, 2, "operator"),
- ir_data.FunctionMapping.OR: (bool_result, bool_args, binary, 2, 2, "operator"),
- ir_data.FunctionMapping.MAXIMUM: (int_result, int_args, n_ary, 1, None,
- "function"),
- ir_data.FunctionMapping.PRESENCE: (bool_result, field_args, n_ary, 1, 1,
- "function"),
- ir_data.FunctionMapping.UPPER_BOUND: (int_result, int_args, n_ary, 1, 1,
- "function"),
- ir_data.FunctionMapping.LOWER_BOUND: (int_result, int_args, n_ary, 1, 1,
- "function"),
- }
- function = expression.function.function
- (set_result_type, check_arg, arg_names, min_args, max_args,
- kind) = functions[function]
- for argument, name in zip(args, arg_names):
- assert name is not None, "Too many arguments to function!"
- check_arg(argument, source_file_name, errors,
- "{} of {} '{}'".format(name, kind,
- expression.function.function_name.text))
- if len(args) < min_args:
- errors.append([
- error.error(source_file_name, expression.source_location,
+ """Type checks an operator that accepts only one set of argument types."""
+ args = expression.function.args
+ int_args = _type_check_integer
+ bool_args = _type_check_boolean
+ field_args = _kind_check_field_reference
+ int_result = _annotate_as_integer
+ bool_result = _annotate_as_boolean
+ binary = ("Left argument", "Right argument")
+ n_ary = ("Argument {}".format(n) for n in range(len(args)))
+ functions = {
+ ir_data.FunctionMapping.ADDITION: (
+ int_result,
+ int_args,
+ binary,
+ 2,
+ 2,
+ "operator",
+ ),
+ ir_data.FunctionMapping.SUBTRACTION: (
+ int_result,
+ int_args,
+ binary,
+ 2,
+ 2,
+ "operator",
+ ),
+ ir_data.FunctionMapping.MULTIPLICATION: (
+ int_result,
+ int_args,
+ binary,
+ 2,
+ 2,
+ "operator",
+ ),
+ ir_data.FunctionMapping.AND: (bool_result, bool_args, binary, 2, 2, "operator"),
+ ir_data.FunctionMapping.OR: (bool_result, bool_args, binary, 2, 2, "operator"),
+ ir_data.FunctionMapping.MAXIMUM: (
+ int_result,
+ int_args,
+ n_ary,
+ 1,
+ None,
+ "function",
+ ),
+ ir_data.FunctionMapping.PRESENCE: (
+ bool_result,
+ field_args,
+ n_ary,
+ 1,
+ 1,
+ "function",
+ ),
+ ir_data.FunctionMapping.UPPER_BOUND: (
+ int_result,
+ int_args,
+ n_ary,
+ 1,
+ 1,
+ "function",
+ ),
+ ir_data.FunctionMapping.LOWER_BOUND: (
+ int_result,
+ int_args,
+ n_ary,
+ 1,
+ 1,
+ "function",
+ ),
+ }
+ function = expression.function.function
+ (set_result_type, check_arg, arg_names, min_args, max_args, kind) = functions[
+ function
+ ]
+ for argument, name in zip(args, arg_names):
+ assert name is not None, "Too many arguments to function!"
+ check_arg(
+ argument,
+ source_file_name,
+ errors,
+ "{} of {} '{}'".format(name, kind, expression.function.function_name.text),
+ )
+ if len(args) < min_args:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
"{} '{}' requires {} {} argument{}.".format(
- kind.title(), expression.function.function_name.text,
+ kind.title(),
+ expression.function.function_name.text,
"exactly" if min_args == max_args else "at least",
- min_args, "s" if min_args > 1 else ""))
- ])
- if max_args is not None and len(args) > max_args:
- errors.append([
- error.error(source_file_name, expression.source_location,
+ min_args,
+ "s" if min_args > 1 else "",
+ ),
+ )
+ ]
+ )
+ if max_args is not None and len(args) > max_args:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
"{} '{}' requires {} {} argument{}.".format(
- kind.title(), expression.function.function_name.text,
+ kind.title(),
+ expression.function.function_name.text,
"exactly" if min_args == max_args else "at most",
- max_args, "s" if max_args > 1 else ""))
- ])
- set_result_type(expression)
+ max_args,
+ "s" if max_args > 1 else "",
+ ),
+ )
+ ]
+ )
+ set_result_type(expression)
def _type_check_local_reference(expression, ir, errors):
- """Annotates the type of a local reference."""
- referrent = ir_util.find_object(expression.field_reference.path[-1], ir)
- assert referrent, "Local reference should be non-None after name resolution."
- if isinstance(referrent, ir_data.RuntimeParameter):
- parameter = referrent
- _set_expression_type_from_physical_type_reference(
- expression, parameter.physical_type_alias.atomic_type.reference, ir)
- return
- field = referrent
- if ir_util.field_is_virtual(field):
- _type_check_expression(field.read_transform,
- expression.field_reference.path[0], ir, errors)
- ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
- return
- if not field.type.HasField("atomic_type"):
- ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType())
- else:
- _set_expression_type_from_physical_type_reference(
- expression, field.type.atomic_type.reference, ir)
+ """Annotates the type of a local reference."""
+ referrent = ir_util.find_object(expression.field_reference.path[-1], ir)
+ assert referrent, "Local reference should be non-None after name resolution."
+ if isinstance(referrent, ir_data.RuntimeParameter):
+ parameter = referrent
+ _set_expression_type_from_physical_type_reference(
+ expression, parameter.physical_type_alias.atomic_type.reference, ir
+ )
+ return
+ field = referrent
+ if ir_util.field_is_virtual(field):
+ _type_check_expression(
+ field.read_transform, expression.field_reference.path[0], ir, errors
+ )
+ ir_data_utils.builder(expression).type.CopyFrom(field.read_transform.type)
+ return
+ if not field.type.HasField("atomic_type"):
+ ir_data_utils.builder(expression).type.opaque.CopyFrom(ir_data.OpaqueType())
+ else:
+ _set_expression_type_from_physical_type_reference(
+ expression, field.type.atomic_type.reference, ir
+ )
def unbounded_expression_type_for_physical_type(type_definition):
- """Gets the ExpressionType for a field of the given TypeDefinition.
+ """Gets the ExpressionType for a field of the given TypeDefinition.
- Arguments:
- type_definition: an ir_data.AddressableUnit.
+ Arguments:
+ type_definition: an ir_data.AddressableUnit.
- Returns:
- An ir_data.ExpressionType with the corresponding expression type filled in:
- for example, [prelude].UInt will result in an ExpressionType with the
- `integer` field filled in.
+ Returns:
+ An ir_data.ExpressionType with the corresponding expression type filled in:
+ for example, [prelude].UInt will result in an ExpressionType with the
+ `integer` field filled in.
- The returned ExpressionType will not have any bounds set.
- """
- # TODO(bolms): Add a `[value_type]` attribute for `external`s.
- if ir_util.get_boolean_attribute(type_definition.attribute,
- attributes.IS_INTEGER):
- return ir_data.ExpressionType(integer=ir_data.IntegerType())
- elif tuple(type_definition.name.canonical_name.object_path) == ("Flag",):
- # This is a hack: the Flag type should say that it is a boolean.
- return ir_data.ExpressionType(boolean=ir_data.BooleanType())
- elif type_definition.HasField("enumeration"):
- return ir_data.ExpressionType(
- enumeration=ir_data.EnumType(
- name=ir_data.Reference(
- canonical_name=type_definition.name.canonical_name)))
- else:
- return ir_data.ExpressionType(opaque=ir_data.OpaqueType())
+ The returned ExpressionType will not have any bounds set.
+ """
+ # TODO(bolms): Add a `[value_type]` attribute for `external`s.
+ if ir_util.get_boolean_attribute(type_definition.attribute, attributes.IS_INTEGER):
+ return ir_data.ExpressionType(integer=ir_data.IntegerType())
+ elif tuple(type_definition.name.canonical_name.object_path) == ("Flag",):
+ # This is a hack: the Flag type should say that it is a boolean.
+ return ir_data.ExpressionType(boolean=ir_data.BooleanType())
+ elif type_definition.HasField("enumeration"):
+ return ir_data.ExpressionType(
+ enumeration=ir_data.EnumType(
+ name=ir_data.Reference(
+ canonical_name=type_definition.name.canonical_name
+ )
+ )
+ )
+ else:
+ return ir_data.ExpressionType(opaque=ir_data.OpaqueType())
-def _set_expression_type_from_physical_type_reference(expression,
- type_reference, ir):
- """Sets the type of an expression to match a physical type."""
- field_type = ir_util.find_object(type_reference, ir)
- assert field_type, "Field type should be non-None after name resolution."
- ir_data_utils.builder(expression).type.CopyFrom(
- unbounded_expression_type_for_physical_type(field_type))
+def _set_expression_type_from_physical_type_reference(expression, type_reference, ir):
+ """Sets the type of an expression to match a physical type."""
+ field_type = ir_util.find_object(type_reference, ir)
+ assert field_type, "Field type should be non-None after name resolution."
+ ir_data_utils.builder(expression).type.CopyFrom(
+ unbounded_expression_type_for_physical_type(field_type)
+ )
def _annotate_parameter_type(parameter, ir, source_file_name, errors):
- if parameter.physical_type_alias.WhichOneof("type") != "atomic_type":
- errors.append([
- error.error(
- source_file_name, parameter.physical_type_alias.source_location,
- "Parameters cannot be arrays.")
- ])
- return
- _set_expression_type_from_physical_type_reference(
- parameter, parameter.physical_type_alias.atomic_type.reference, ir)
+ if parameter.physical_type_alias.WhichOneof("type") != "atomic_type":
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ parameter.physical_type_alias.source_location,
+ "Parameters cannot be arrays.",
+ )
+ ]
+ )
+ return
+ _set_expression_type_from_physical_type_reference(
+ parameter, parameter.physical_type_alias.atomic_type.reference, ir
+ )
def _types_are_compatible(a, b):
- """Returns true if a and b have compatible types."""
- if a.type.WhichOneof("type") != b.type.WhichOneof("type"):
- return False
- elif a.type.WhichOneof("type") == "enumeration":
- return (ir_util.hashable_form_of_reference(a.type.enumeration.name) ==
- ir_util.hashable_form_of_reference(b.type.enumeration.name))
- elif a.type.WhichOneof("type") in ("integer", "boolean"):
- # All integers are compatible with integers; booleans are compatible with
- # booleans
- return True
- else:
- assert False, "_types_are_compatible works with enums, integers, booleans."
+ """Returns true if a and b have compatible types."""
+ if a.type.WhichOneof("type") != b.type.WhichOneof("type"):
+ return False
+ elif a.type.WhichOneof("type") == "enumeration":
+ return ir_util.hashable_form_of_reference(
+ a.type.enumeration.name
+ ) == ir_util.hashable_form_of_reference(b.type.enumeration.name)
+ elif a.type.WhichOneof("type") in ("integer", "boolean"):
+ # All integers are compatible with integers; booleans are compatible with
+ # booleans
+ return True
+ else:
+ assert False, "_types_are_compatible works with enums, integers, booleans."
def _type_check_comparison_operator(expression, source_file_name, errors):
- """Checks the type of a comparison operator (==, !=, <, >, >=, <=)."""
- # Applying less than or greater than to a boolean is likely a mistake, so
- # only equality and inequality are allowed for booleans.
- if expression.function.function in (ir_data.FunctionMapping.EQUALITY,
- ir_data.FunctionMapping.INEQUALITY):
- acceptable_types = ("integer", "boolean", "enumeration")
- acceptable_types_for_humans = "an integer, boolean, or enum"
- else:
- acceptable_types = ("integer", "enumeration")
- acceptable_types_for_humans = "an integer or enum"
- left = expression.function.args[0]
- right = expression.function.args[1]
- for (argument, name) in ((left, "Left"), (right, "Right")):
- if argument.type.WhichOneof("type") not in acceptable_types:
- errors.append([
- error.error(source_file_name, argument.source_location,
- "{} argument of operator '{}' must be {}.".format(
- name, expression.function.function_name.text,
- acceptable_types_for_humans))
- ])
- return
- if not _types_are_compatible(left, right):
- errors.append([
- error.error(source_file_name, expression.source_location,
+ """Checks the type of a comparison operator (==, !=, <, >, >=, <=)."""
+ # Applying less than or greater than to a boolean is likely a mistake, so
+ # only equality and inequality are allowed for booleans.
+ if expression.function.function in (
+ ir_data.FunctionMapping.EQUALITY,
+ ir_data.FunctionMapping.INEQUALITY,
+ ):
+ acceptable_types = ("integer", "boolean", "enumeration")
+ acceptable_types_for_humans = "an integer, boolean, or enum"
+ else:
+ acceptable_types = ("integer", "enumeration")
+ acceptable_types_for_humans = "an integer or enum"
+ left = expression.function.args[0]
+ right = expression.function.args[1]
+ for argument, name in ((left, "Left"), (right, "Right")):
+ if argument.type.WhichOneof("type") not in acceptable_types:
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ argument.source_location,
+ "{} argument of operator '{}' must be {}.".format(
+ name,
+ expression.function.function_name.text,
+ acceptable_types_for_humans,
+ ),
+ )
+ ]
+ )
+ return
+ if not _types_are_compatible(left, right):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
"Both arguments of operator '{}' must have the same "
- "type.".format(expression.function.function_name.text))
- ])
- _annotate_as_boolean(expression)
+ "type.".format(expression.function.function_name.text),
+ )
+ ]
+ )
+ _annotate_as_boolean(expression)
def _type_check_choice_operator(expression, source_file_name, errors):
- """Checks the type of the choice operator cond ? if_true : if_false."""
- condition = expression.function.args[0]
- if condition.type.WhichOneof("type") != "boolean":
- errors.append([
- error.error(source_file_name, condition.source_location,
- "Condition of operator '?:' must be a boolean.")
- ])
- if_true = expression.function.args[1]
- if if_true.type.WhichOneof("type") not in ("integer", "boolean",
- "enumeration"):
- errors.append([
- error.error(source_file_name, if_true.source_location,
+ """Checks the type of the choice operator cond ? if_true : if_false."""
+ condition = expression.function.args[0]
+ if condition.type.WhichOneof("type") != "boolean":
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ condition.source_location,
+ "Condition of operator '?:' must be a boolean.",
+ )
+ ]
+ )
+ if_true = expression.function.args[1]
+ if if_true.type.WhichOneof("type") not in ("integer", "boolean", "enumeration"):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ if_true.source_location,
"If-true clause of operator '?:' must be an integer, "
- "boolean, or enum.")
- ])
- return
- if_false = expression.function.args[2]
- if not _types_are_compatible(if_true, if_false):
- errors.append([
- error.error(source_file_name, expression.source_location,
+ "boolean, or enum.",
+ )
+ ]
+ )
+ return
+ if_false = expression.function.args[2]
+ if not _types_are_compatible(if_true, if_false):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ expression.source_location,
"The if-true and if-false clauses of operator '?:' must "
- "have the same type.")
- ])
- if if_true.type.WhichOneof("type") == "integer":
- _annotate_as_integer(expression)
- elif if_true.type.WhichOneof("type") == "boolean":
- _annotate_as_boolean(expression)
- elif if_true.type.WhichOneof("type") == "enumeration":
- ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(if_true.type.enumeration.name)
- else:
- assert False, "Unexpected type for if_true."
+ "have the same type.",
+ )
+ ]
+ )
+ if if_true.type.WhichOneof("type") == "integer":
+ _annotate_as_integer(expression)
+ elif if_true.type.WhichOneof("type") == "boolean":
+ _annotate_as_boolean(expression)
+ elif if_true.type.WhichOneof("type") == "enumeration":
+ ir_data_utils.builder(expression).type.enumeration.name.CopyFrom(
+ if_true.type.enumeration.name
+ )
+ else:
+ assert False, "Unexpected type for if_true."
def _type_check_boolean_constant(expression):
- _annotate_as_boolean(expression)
+ _annotate_as_boolean(expression)
def _type_check_builtin_reference(expression):
- name = expression.builtin_reference.canonical_name.object_path[0]
- if name == "$is_statically_sized":
- _annotate_as_boolean(expression)
- elif name == "$static_size_in_bits":
- _annotate_as_integer(expression)
- else:
- assert False, "Unknown builtin '{}'.".format(name)
+ name = expression.builtin_reference.canonical_name.object_path[0]
+ if name == "$is_statically_sized":
+ _annotate_as_boolean(expression)
+ elif name == "$static_size_in_bits":
+ _annotate_as_integer(expression)
+ else:
+ assert False, "Unknown builtin '{}'.".format(name)
def _type_check_array_size(expression, source_file_name, errors):
- _type_check_integer(expression, source_file_name, errors, "Array size")
+ _type_check_integer(expression, source_file_name, errors, "Array size")
def _type_check_field_location(location, source_file_name, errors):
- _type_check_integer(location.start, source_file_name, errors,
- "Start of field")
- _type_check_integer(location.size, source_file_name, errors, "Size of field")
+ _type_check_integer(location.start, source_file_name, errors, "Start of field")
+ _type_check_integer(location.size, source_file_name, errors, "Size of field")
def _type_check_field_existence_condition(field, source_file_name, errors):
- _type_check_boolean(field.existence_condition, source_file_name, errors,
- "Existence condition")
+ _type_check_boolean(
+ field.existence_condition, source_file_name, errors, "Existence condition"
+ )
def _type_name_for_error_messages(expression_type):
- if expression_type.WhichOneof("type") == "integer":
- return "integer"
- elif expression_type.WhichOneof("type") == "enumeration":
- # TODO(bolms): Should this be the fully-qualified name?
- return expression_type.enumeration.name.canonical_name.object_path[-1]
- assert False, "Shouldn't be here."
+ if expression_type.WhichOneof("type") == "integer":
+ return "integer"
+ elif expression_type.WhichOneof("type") == "enumeration":
+ # TODO(bolms): Should this be the fully-qualified name?
+ return expression_type.enumeration.name.canonical_name.object_path[-1]
+ assert False, "Shouldn't be here."
def _type_check_passed_parameters(atomic_type, ir, source_file_name, errors):
- """Checks the types of parameters to a parameterized physical type."""
- referenced_type = ir_util.find_object(atomic_type.reference.canonical_name,
- ir)
- if (len(referenced_type.runtime_parameter) !=
- len(atomic_type.runtime_parameter)):
- errors.append([
- error.error(
- source_file_name, atomic_type.source_location,
- "Type {} requires {} parameter{}; {} parameter{} given.".format(
- referenced_type.name.name.text,
- len(referenced_type.runtime_parameter),
- "" if len(referenced_type.runtime_parameter) == 1 else "s",
- len(atomic_type.runtime_parameter),
- "" if len(atomic_type.runtime_parameter) == 1 else "s")),
- error.note(
- atomic_type.reference.canonical_name.module_file,
- referenced_type.source_location,
- "Definition of type {}.".format(referenced_type.name.name.text))
- ])
- return
- for i in range(len(referenced_type.runtime_parameter)):
- if referenced_type.runtime_parameter[i].type.WhichOneof("type") not in (
- "integer", "boolean", "enumeration"):
- # _type_check_parameter will catch invalid parameter types at the
- # definition site; no need for another, probably-confusing error at any
- # usage sites.
- continue
- if (atomic_type.runtime_parameter[i].type.WhichOneof("type") !=
- referenced_type.runtime_parameter[i].type.WhichOneof("type")):
- errors.append([
- error.error(
- source_file_name,
- atomic_type.runtime_parameter[i].source_location,
- "Parameter {} of type {} must be {}, not {}.".format(
- i, referenced_type.name.name.text,
- _type_name_for_error_messages(
- referenced_type.runtime_parameter[i].type),
- _type_name_for_error_messages(
- atomic_type.runtime_parameter[i].type))),
- error.note(
- atomic_type.reference.canonical_name.module_file,
- referenced_type.runtime_parameter[i].source_location,
- "Parameter {} of {}.".format(i, referenced_type.name.name.text))
- ])
+ """Checks the types of parameters to a parameterized physical type."""
+ referenced_type = ir_util.find_object(atomic_type.reference.canonical_name, ir)
+ if len(referenced_type.runtime_parameter) != len(atomic_type.runtime_parameter):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ atomic_type.source_location,
+ "Type {} requires {} parameter{}; {} parameter{} given.".format(
+ referenced_type.name.name.text,
+ len(referenced_type.runtime_parameter),
+ "" if len(referenced_type.runtime_parameter) == 1 else "s",
+ len(atomic_type.runtime_parameter),
+ "" if len(atomic_type.runtime_parameter) == 1 else "s",
+ ),
+ ),
+ error.note(
+ atomic_type.reference.canonical_name.module_file,
+ referenced_type.source_location,
+ "Definition of type {}.".format(referenced_type.name.name.text),
+ ),
+ ]
+ )
+ return
+ for i in range(len(referenced_type.runtime_parameter)):
+ if referenced_type.runtime_parameter[i].type.WhichOneof("type") not in (
+ "integer",
+ "boolean",
+ "enumeration",
+ ):
+ # _type_check_parameter will catch invalid parameter types at the
+ # definition site; no need for another, probably-confusing error at any
+ # usage sites.
+ continue
+ if atomic_type.runtime_parameter[i].type.WhichOneof(
+ "type"
+ ) != referenced_type.runtime_parameter[i].type.WhichOneof("type"):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
+ atomic_type.runtime_parameter[i].source_location,
+ "Parameter {} of type {} must be {}, not {}.".format(
+ i,
+ referenced_type.name.name.text,
+ _type_name_for_error_messages(
+ referenced_type.runtime_parameter[i].type
+ ),
+ _type_name_for_error_messages(
+ atomic_type.runtime_parameter[i].type
+ ),
+ ),
+ ),
+ error.note(
+ atomic_type.reference.canonical_name.module_file,
+ referenced_type.runtime_parameter[i].source_location,
+ "Parameter {} of {}.".format(i, referenced_type.name.name.text),
+ ),
+ ]
+ )
def _type_check_parameter(runtime_parameter, source_file_name, errors):
- """Checks the type of a parameter to a physical type."""
- if runtime_parameter.type.WhichOneof("type") not in ("integer",
- "enumeration"):
- errors.append([
- error.error(source_file_name,
+ """Checks the type of a parameter to a physical type."""
+ if runtime_parameter.type.WhichOneof("type") not in ("integer", "enumeration"):
+ errors.append(
+ [
+ error.error(
+ source_file_name,
runtime_parameter.physical_type_alias.source_location,
- "Runtime parameters must be integer or enum.")
- ])
+ "Runtime parameters must be integer or enum.",
+ )
+ ]
+ )
def annotate_types(ir):
- """Adds type annotations to all expressions in ir.
+ """Adds type annotations to all expressions in ir.
- annotate_types adds type information to all expressions (and subexpressions)
- in the IR. Additionally, it checks expressions for internal type consistency:
- it will generate an error for constructs like "1 + true", where the types of
- the operands are not accepted by the operator.
+ annotate_types adds type information to all expressions (and subexpressions)
+ in the IR. Additionally, it checks expressions for internal type consistency:
+ it will generate an error for constructs like "1 + true", where the types of
+ the operands are not accepted by the operator.
- Arguments:
- ir: an IR to which to add type annotations
+ Arguments:
+ ir: an IR to which to add type annotations
- Returns:
- A (possibly empty) list of errors.
- """
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Expression], _type_check_expression,
- skip_descendants_of={ir_data.Expression},
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.RuntimeParameter], _annotate_parameter_type,
- parameters={"errors": errors})
- return errors
+ Returns:
+ A (possibly empty) list of errors.
+ """
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Expression],
+ _type_check_expression,
+ skip_descendants_of={ir_data.Expression},
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.RuntimeParameter],
+ _annotate_parameter_type,
+ parameters={"errors": errors},
+ )
+ return errors
def check_types(ir):
- """Checks that expressions within the IR have the correct top-level types.
+ """Checks that expressions within the IR have the correct top-level types.
- check_types ensures that expressions at the top level have correct types; in
- particular, it ensures that array sizes are integers ("UInt[true]" is not a
- valid array type) and that the starts and ends of ranges are integers.
+ check_types ensures that expressions at the top level have correct types; in
+ particular, it ensures that array sizes are integers ("UInt[true]" is not a
+ valid array type) and that the starts and ends of ranges are integers.
- Arguments:
- ir: an IR to type check.
+ Arguments:
+ ir: an IR to type check.
- Returns:
- A (possibly empty) list of errors.
- """
- errors = []
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.FieldLocation], _type_check_field_location,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.ArrayType, ir_data.Expression], _type_check_array_size,
- skip_descendants_of={ir_data.AtomicType},
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Field], _type_check_field_existence_condition,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.RuntimeParameter], _type_check_parameter,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.AtomicType], _type_check_passed_parameters,
- parameters={"errors": errors})
- return errors
+ Returns:
+ A (possibly empty) list of errors.
+ """
+ errors = []
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.FieldLocation],
+ _type_check_field_location,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.ArrayType, ir_data.Expression],
+ _type_check_array_size,
+ skip_descendants_of={ir_data.AtomicType},
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.Field],
+ _type_check_field_existence_condition,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.RuntimeParameter],
+ _type_check_parameter,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.AtomicType],
+ _type_check_passed_parameters,
+ parameters={"errors": errors},
+ )
+ return errors
diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py
index d308fed..995f20d 100644
--- a/compiler/front_end/type_check_test.py
+++ b/compiler/front_end/type_check_test.py
@@ -24,657 +24,1053 @@
class TypeAnnotationTest(unittest.TestCase):
- def _make_ir(self, emb_text):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": emb_text}),
- stop_before_step="annotate_types")
- assert not errors, errors
- return ir
+ def _make_ir(self, emb_text):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_text}),
+ stop_before_step="annotate_types",
+ )
+ assert not errors, errors
+ return ir
- def test_adds_integer_constant_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt:8[] y\n")
- self.assertEqual([], type_check.annotate_types(ir))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "integer")
+ def test_adds_integer_constant_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+1] UInt:8[] y\n"
+ )
+ self.assertEqual([], type_check.annotate_types(ir))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "integer")
- def test_adds_boolean_constant_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+true] UInt:8[] y\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)),
- ir_data_utils.IrDataSerializer(ir).to_json(indent=2))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+ def test_adds_boolean_constant_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+true] UInt:8[] y\n"
+ )
+ self.assertEqual(
+ [],
+ error.filter_errors(type_check.annotate_types(ir)),
+ ir_data_utils.IrDataSerializer(ir).to_json(indent=2),
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "boolean")
- def test_adds_enum_constant_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+Enum.VALUE] UInt x\n"
- "enum Enum:\n"
- " VALUE = 1\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- expression = ir.module[0].type[0].structure.field[0].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
- enum_type_name = expression.type.enumeration.name.canonical_name
- self.assertEqual(enum_type_name.module_file, "m.emb")
- self.assertEqual(enum_type_name.object_path[0], "Enum")
+ def test_adds_enum_constant_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+Enum.VALUE] UInt x\n"
+ "enum Enum:\n"
+ " VALUE = 1\n"
+ )
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[0].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
+ enum_type_name = expression.type.enumeration.name.canonical_name
+ self.assertEqual(enum_type_name.module_file, "m.emb")
+ self.assertEqual(enum_type_name.object_path[0], "Enum")
- def test_adds_enum_field_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Enum x\n"
- " 1 [+x] UInt y\n"
- "enum Enum:\n"
- " VALUE = 1\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
- enum_type_name = expression.type.enumeration.name.canonical_name
- self.assertEqual(enum_type_name.module_file, "m.emb")
- self.assertEqual(enum_type_name.object_path[0], "Enum")
+ def test_adds_enum_field_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Enum x\n"
+ " 1 [+x] UInt y\n"
+ "enum Enum:\n"
+ " VALUE = 1\n"
+ )
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "enumeration")
+ enum_type_name = expression.type.enumeration.name.canonical_name
+ self.assertEqual(enum_type_name.module_file, "m.emb")
+ self.assertEqual(enum_type_name.object_path[0], "Enum")
- def test_adds_integer_operation_types(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1+1] UInt:8[] y\n")
- self.assertEqual([], type_check.annotate_types(ir))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "integer")
- self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
- "integer")
- self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
- "integer")
+ def test_adds_integer_operation_types(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+1+1] UInt:8[] y\n"
+ )
+ self.assertEqual([], type_check.annotate_types(ir))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "integer")
+ self.assertEqual(expression.function.args[0].type.WhichOneof("type"), "integer")
+ self.assertEqual(expression.function.args[1].type.WhichOneof("type"), "integer")
- def test_adds_enum_operation_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+Enum.VAL==Enum.VAL] UInt:8[] y\n"
- "enum Enum:\n"
- " VAL = 1\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "boolean")
- self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
- "enumeration")
- self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
- "enumeration")
+ def test_adds_enum_operation_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+Enum.VAL==Enum.VAL] UInt:8[] y\n"
+ "enum Enum:\n"
+ " VAL = 1\n"
+ )
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+ self.assertEqual(
+ expression.function.args[0].type.WhichOneof("type"), "enumeration"
+ )
+ self.assertEqual(
+ expression.function.args[1].type.WhichOneof("type"), "enumeration"
+ )
- def test_adds_enum_comparison_operation_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+Enum.VAL>=Enum.VAL] UInt:8[] y\n"
- "enum Enum:\n"
- " VAL = 1\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "boolean")
- self.assertEqual(expression.function.args[0].type.WhichOneof("type"),
- "enumeration")
- self.assertEqual(expression.function.args[1].type.WhichOneof("type"),
- "enumeration")
+ def test_adds_enum_comparison_operation_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+Enum.VAL>=Enum.VAL] UInt:8[] y\n"
+ "enum Enum:\n"
+ " VAL = 1\n"
+ )
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "boolean")
+ self.assertEqual(
+ expression.function.args[0].type.WhichOneof("type"), "enumeration"
+ )
+ self.assertEqual(
+ expression.function.args[1].type.WhichOneof("type"), "enumeration"
+ )
- def test_adds_integer_field_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+x] UInt:8[] y\n")
- self.assertEqual([], type_check.annotate_types(ir))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "integer")
+ def test_adds_integer_field_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " 1 [+x] UInt:8[] y\n"
+ )
+ self.assertEqual([], type_check.annotate_types(ir))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "integer")
- def test_adds_opaque_field_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Bar x\n"
- " 1 [+x] UInt:8[] y\n"
- "struct Bar:\n"
- " 0 [+1] UInt z\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "opaque")
+ def test_adds_opaque_field_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Bar x\n"
+ " 1 [+x] UInt:8[] y\n"
+ "struct Bar:\n"
+ " 0 [+1] UInt z\n"
+ )
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "opaque")
- def test_adds_opaque_field_type_for_array(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+x] UInt:8[] y\n")
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual(expression.type.WhichOneof("type"), "opaque")
+ def test_adds_opaque_field_type_for_array(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt:8[] x\n" " 1 [+x] UInt:8[] y\n"
+ )
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(expression.type.WhichOneof("type"), "opaque")
- def test_error_on_bad_plus_operand_types(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1+true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[1].source_location,
- "Right argument of operator '+' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_plus_operand_types(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1+true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[1].source_location,
+ "Right argument of operator '+' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_minus_operand_types(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1-true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[1].source_location,
- "Right argument of operator '-' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_minus_operand_types(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1-true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[1].source_location,
+ "Right argument of operator '-' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_times_operand_types(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1*true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[1].source_location,
- "Right argument of operator '*' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_times_operand_types(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1*true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[1].source_location,
+ "Right argument of operator '*' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_equality_left_operand(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+x==x] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Left argument of operator '==' must be an integer, "
- "boolean, or enum.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_equality_left_operand(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+x==x] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Left argument of operator '==' must be an integer, "
+ "boolean, or enum.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_equality_right_operand(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+1==x] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[1].source_location,
- "Right argument of operator '==' must be an integer, "
- "boolean, or enum.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_equality_right_operand(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+1==x] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[1].source_location,
+ "Right argument of operator '==' must be an integer, "
+ "boolean, or enum.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_equality_mismatched_operands_int_bool(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1==true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Both arguments of operator '==' must have the same "
- "type.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_equality_mismatched_operands_int_bool(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1==true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Both arguments of operator '==' must have the same " "type.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_mismatched_comparison_operands(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8 x\n"
- " 1 [+x>=Bar.BAR] UInt:8[] y\n"
- "enum Bar:\n"
- " BAR = 1\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Both arguments of operator '>=' must have the same "
- "type.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_mismatched_comparison_operands(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8 x\n"
+ " 1 [+x>=Bar.BAR] UInt:8[] y\n"
+ "enum Bar:\n"
+ " BAR = 1\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Both arguments of operator '>=' must have the same " "type.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_equality_mismatched_operands_bool_int(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+true!=1] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Both arguments of operator '!=' must have the same "
- "type.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_equality_mismatched_operands_bool_int(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+true!=1] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Both arguments of operator '!=' must have the same " "type.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_equality_mismatched_operands_enum_enum(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+Bar.BAR==Baz.BAZ] UInt:8[] y\n"
- "enum Bar:\n"
- " BAR = 1\n"
- "enum Baz:\n"
- " BAZ = 1\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Both arguments of operator '==' must have the same "
- "type.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_equality_mismatched_operands_enum_enum(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+Bar.BAR==Baz.BAZ] UInt:8[] y\n"
+ "enum Bar:\n"
+ " BAR = 1\n"
+ "enum Baz:\n"
+ " BAZ = 1\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Both arguments of operator '==' must have the same " "type.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_choice_condition_operand(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+5 ? 0 : 1] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- condition_arg = expression.function.args[0]
- self.assertEqual([
- [error.error("m.emb", condition_arg.source_location,
- "Condition of operator '?:' must be a boolean.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_choice_condition_operand(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+5 ? 0 : 1] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ condition_arg = expression.function.args[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ condition_arg.source_location,
+ "Condition of operator '?:' must be a boolean.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_choice_if_true_operand(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+true ? x : x] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- if_true_arg = expression.function.args[1]
- self.assertEqual([
- [error.error("m.emb", if_true_arg.source_location,
- "If-true clause of operator '?:' must be an integer, "
- "boolean, or enum.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_choice_if_true_operand(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+true ? x : x] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ if_true_arg = expression.function.args[1]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ if_true_arg.source_location,
+ "If-true clause of operator '?:' must be an integer, "
+ "boolean, or enum.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_choice_of_bools(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+true ? true : false] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- self.assertEqual("boolean", expression.type.WhichOneof("type"))
+ def test_choice_of_bools(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+true ? true : false] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ self.assertEqual("boolean", expression.type.WhichOneof("type"))
- def test_choice_of_integers(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+true ? 0 : 100] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ def test_choice_of_integers(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+true ? 0 : 100] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual([], type_check.annotate_types(ir))
+ self.assertEqual("integer", expression.type.WhichOneof("type"))
- def test_choice_of_enums(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] enum xx:\n"
- " XX = 1\n"
- " YY = 1\n"
- " 1 [+true ? Xx.XX : Xx.YY] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
- self.assertEqual("enumeration", expression.type.WhichOneof("type"))
- self.assertFalse(expression.type.enumeration.HasField("value"))
- self.assertEqual(
- "m.emb", expression.type.enumeration.name.canonical_name.module_file)
- self.assertEqual(
- "Foo", expression.type.enumeration.name.canonical_name.object_path[0])
- self.assertEqual(
- "Xx", expression.type.enumeration.name.canonical_name.object_path[1])
+ def test_choice_of_enums(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] enum xx:\n"
+ " XX = 1\n"
+ " YY = 1\n"
+ " 1 [+true ? Xx.XX : Xx.YY] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)))
+ self.assertEqual("enumeration", expression.type.WhichOneof("type"))
+ self.assertFalse(expression.type.enumeration.HasField("value"))
+ self.assertEqual(
+ "m.emb", expression.type.enumeration.name.canonical_name.module_file
+ )
+ self.assertEqual(
+ "Foo", expression.type.enumeration.name.canonical_name.object_path[0]
+ )
+ self.assertEqual(
+ "Xx", expression.type.enumeration.name.canonical_name.object_path[1]
+ )
- def test_error_on_bad_choice_mismatched_operands(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+true ? 0 : true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "The if-true and if-false clauses of operator '?:' must "
- "have the same type.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_choice_mismatched_operands(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+true ? 0 : true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "The if-true and if-false clauses of operator '?:' must "
+ "have the same type.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_choice_mismatched_enum_operands(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+true ? Baz.BAZ : Bar.BAR] UInt:8[] y\n"
- "enum Bar:\n"
- " BAR = 1\n"
- "enum Baz:\n"
- " BAZ = 1\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "The if-true and if-false clauses of operator '?:' must "
- "have the same type.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_choice_mismatched_enum_operands(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+true ? Baz.BAZ : Bar.BAR] UInt:8[] y\n"
+ "enum Bar:\n"
+ " BAR = 1\n"
+ "enum Baz:\n"
+ " BAZ = 1\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "The if-true and if-false clauses of operator '?:' must "
+ "have the same type.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_left_operand_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+true+1] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Left argument of operator '+' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_left_operand_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+true+1] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Left argument of operator '+' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_opaque_operand_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+x+1] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Left argument of operator '+' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_opaque_operand_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+x+1] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Left argument of operator '+' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_left_comparison_operand_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+true<1] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error(
- "m.emb", expression.function.args[0].source_location,
- "Left argument of operator '<' must be an integer or enum.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_left_comparison_operand_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+true<1] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Left argument of operator '<' must be an integer or enum.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_right_comparison_operand_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1>=true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error(
- "m.emb", expression.function.args[1].source_location,
- "Right argument of operator '>=' must be an integer or enum.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_right_comparison_operand_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1>=true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[1].source_location,
+ "Right argument of operator '>=' must be an integer or enum.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_bad_boolean_operand_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " 1 [+1&&true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Left argument of operator '&&' must be a boolean.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_boolean_operand_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1&&true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Left argument of operator '&&' must be a boolean.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_max_return_type(self):
- ir = self._make_ir("struct Foo:\n"
- " $max(1, 2, 3) [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ def test_max_return_type(self):
+ ir = self._make_ir("struct Foo:\n" " $max(1, 2, 3) [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual([], type_check.annotate_types(ir))
+ self.assertEqual("integer", expression.type.WhichOneof("type"))
- def test_error_on_bad_max_argument(self):
- ir = self._make_ir("struct Foo:\n"
- " $max(Bar.XX) [+1] UInt:8[] x\n"
- "enum Bar:\n"
- " XX = 0\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Argument 0 of function '$max' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_bad_max_argument(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " $max(Bar.XX) [+1] UInt:8[] x\n"
+ "enum Bar:\n"
+ " XX = 0\n"
+ )
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Argument 0 of function '$max' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_no_max_argument(self):
- ir = self._make_ir("struct Foo:\n"
- " $max() [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$max' requires at least 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_no_max_argument(self):
+ ir = self._make_ir("struct Foo:\n" " $max() [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$max' requires at least 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_upper_bound_return_type(self):
- ir = self._make_ir("struct Foo:\n"
- " $upper_bound(3) [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ def test_upper_bound_return_type(self):
+ ir = self._make_ir("struct Foo:\n" " $upper_bound(3) [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual([], type_check.annotate_types(ir))
+ self.assertEqual("integer", expression.type.WhichOneof("type"))
- def test_upper_bound_too_few_arguments(self):
- ir = self._make_ir("struct Foo:\n"
- " $upper_bound() [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$upper_bound' requires exactly 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_upper_bound_too_few_arguments(self):
+ ir = self._make_ir("struct Foo:\n" " $upper_bound() [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$upper_bound' requires exactly 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_upper_bound_too_many_arguments(self):
- ir = self._make_ir("struct Foo:\n"
- " $upper_bound(1, 2) [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$upper_bound' requires exactly 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_upper_bound_too_many_arguments(self):
+ ir = self._make_ir("struct Foo:\n" " $upper_bound(1, 2) [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$upper_bound' requires exactly 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_upper_bound_wrong_argument_type(self):
- ir = self._make_ir("struct Foo:\n"
- " $upper_bound(Bar.XX) [+1] UInt:8[] x\n"
- "enum Bar:\n"
- " XX = 0\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error(
- "m.emb", expression.function.args[0].source_location,
- "Argument 0 of function '$upper_bound' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_upper_bound_wrong_argument_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " $upper_bound(Bar.XX) [+1] UInt:8[] x\n"
+ "enum Bar:\n"
+ " XX = 0\n"
+ )
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Argument 0 of function '$upper_bound' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_lower_bound_return_type(self):
- ir = self._make_ir("struct Foo:\n"
- " $lower_bound(3) [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([], type_check.annotate_types(ir))
- self.assertEqual("integer", expression.type.WhichOneof("type"))
+ def test_lower_bound_return_type(self):
+ ir = self._make_ir("struct Foo:\n" " $lower_bound(3) [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual([], type_check.annotate_types(ir))
+ self.assertEqual("integer", expression.type.WhichOneof("type"))
- def test_lower_bound_too_few_arguments(self):
- ir = self._make_ir("struct Foo:\n"
- " $lower_bound() [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$lower_bound' requires exactly 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_lower_bound_too_few_arguments(self):
+ ir = self._make_ir("struct Foo:\n" " $lower_bound() [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$lower_bound' requires exactly 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_lower_bound_too_many_arguments(self):
- ir = self._make_ir("struct Foo:\n"
- " $lower_bound(1, 2) [+1] UInt:8[] x\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$lower_bound' requires exactly 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_lower_bound_too_many_arguments(self):
+ ir = self._make_ir("struct Foo:\n" " $lower_bound(1, 2) [+1] UInt:8[] x\n")
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$lower_bound' requires exactly 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_lower_bound_wrong_argument_type(self):
- ir = self._make_ir("struct Foo:\n"
- " $lower_bound(Bar.XX) [+1] UInt:8[] x\n"
- "enum Bar:\n"
- " XX = 0\n")
- expression = ir.module[0].type[0].structure.field[0].location.start
- self.assertEqual([
- [error.error(
- "m.emb", expression.function.args[0].source_location,
- "Argument 0 of function '$lower_bound' must be an integer.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_lower_bound_wrong_argument_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " $lower_bound(Bar.XX) [+1] UInt:8[] x\n"
+ "enum Bar:\n"
+ " XX = 0\n"
+ )
+ expression = ir.module[0].type[0].structure.field[0].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Argument 0 of function '$lower_bound' must be an integer.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_static_reference_to_physical_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = Foo.x\n")
- static_ref = ir.module[0].type[0].structure.field[1].read_transform
- physical_field = ir.module[0].type[0].structure.field[0]
- self.assertEqual([
- [error.error("m.emb", static_ref.source_location,
- "Static references to physical fields are not allowed."),
- error.note("m.emb", physical_field.source_location,
- "x is a physical field.")]
- ], type_check.annotate_types(ir))
+ def test_error_static_reference_to_physical_field(self):
+ ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = Foo.x\n")
+ static_ref = ir.module[0].type[0].structure.field[1].read_transform
+ physical_field = ir.module[0].type[0].structure.field[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ static_ref.source_location,
+ "Static references to physical fields are not allowed.",
+ ),
+ error.note(
+ "m.emb",
+ physical_field.source_location,
+ "x is a physical field.",
+ ),
+ ]
+ ],
+ type_check.annotate_types(ir),
+ )
- def test_error_on_non_field_argument_to_has(self):
- ir = self._make_ir("struct Foo:\n"
- " if $present(0):\n"
- " 0 [+1] UInt x\n")
- expression = ir.module[0].type[0].structure.field[0].existence_condition
- self.assertEqual([
- [error.error("m.emb", expression.function.args[0].source_location,
- "Argument 0 of function '$present' must be a field.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_non_field_argument_to_has(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " if $present(0):\n" " 0 [+1] UInt x\n"
+ )
+ expression = ir.module[0].type[0].structure.field[0].existence_condition
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.function.args[0].source_location,
+ "Argument 0 of function '$present' must be a field.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_no_argument_has(self):
- ir = self._make_ir("struct Foo:\n"
- " if $present():\n"
- " 0 [+1] UInt x\n")
- expression = ir.module[0].type[0].structure.field[0].existence_condition
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$present' requires exactly 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_no_argument_has(self):
+ ir = self._make_ir("struct Foo:\n" " if $present():\n" " 0 [+1] UInt x\n")
+ expression = ir.module[0].type[0].structure.field[0].existence_condition
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$present' requires exactly 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_error_on_too_many_argument_has(self):
- ir = self._make_ir("struct Foo:\n"
- " if $present(y, y):\n"
- " 0 [+1] UInt x\n"
- " 1 [+1] UInt y\n")
- expression = ir.module[0].type[0].structure.field[0].existence_condition
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Function '$present' requires exactly 1 argument.")]
- ], error.filter_errors(type_check.annotate_types(ir)))
+ def test_error_on_too_many_argument_has(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " if $present(y, y):\n"
+ " 0 [+1] UInt x\n"
+ " 1 [+1] UInt y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[0].existence_condition
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Function '$present' requires exactly 1 argument.",
+ )
+ ]
+ ],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
- def test_checks_that_parameters_are_atomic_types(self):
- ir = self._make_ir("struct Foo(y: UInt:8[1]):\n"
- " 0 [+1] UInt x\n")
- error_parameter = ir.module[0].type[0].runtime_parameter[0]
- error_location = error_parameter.physical_type_alias.source_location
- self.assertEqual(
- [[error.error("m.emb", error_location,
- "Parameters cannot be arrays.")]],
- error.filter_errors(type_check.annotate_types(ir)))
+ def test_checks_that_parameters_are_atomic_types(self):
+ ir = self._make_ir("struct Foo(y: UInt:8[1]):\n" " 0 [+1] UInt x\n")
+ error_parameter = ir.module[0].type[0].runtime_parameter[0]
+ error_location = error_parameter.physical_type_alias.source_location
+ self.assertEqual(
+ [[error.error("m.emb", error_location, "Parameters cannot be arrays.")]],
+ error.filter_errors(type_check.annotate_types(ir)),
+ )
class TypeCheckTest(unittest.TestCase):
- def _make_ir(self, emb_text):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": emb_text}),
- stop_before_step="check_types")
- assert not errors, errors
- return ir
+ def _make_ir(self, emb_text):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_text}),
+ stop_before_step="check_types",
+ )
+ assert not errors, errors
+ return ir
- def test_error_on_opaque_type_in_field_start(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " x [+10] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Start of field must be an integer.")]
- ], type_check.check_types(ir))
+ def test_error_on_opaque_type_in_field_start(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " x [+10] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Start of field must be an integer.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_boolean_type_in_field_start(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " true [+10] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.start
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Start of field must be an integer.")]
- ], type_check.check_types(ir))
+ def test_error_on_boolean_type_in_field_start(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " true [+10] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.start
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Start of field must be an integer.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_opaque_type_in_field_size(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+x] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Size of field must be an integer.")]
- ], type_check.check_types(ir))
+ def test_error_on_opaque_type_in_field_size(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+x] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Size of field must be an integer.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_boolean_type_in_field_size(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+true] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].location.size
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Size of field must be an integer.")]
- ], type_check.check_types(ir))
+ def test_error_on_boolean_type_in_field_size(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+true] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].location.size
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Size of field must be an integer.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_opaque_type_in_array_size(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+9] UInt:8[x] y\n")
- expression = (ir.module[0].type[0].structure.field[1].type.array_type.
- element_count)
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Array size must be an integer.")]
- ], type_check.check_types(ir))
+ def test_error_on_opaque_type_in_array_size(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+9] UInt:8[x] y\n"
+ )
+ expression = (
+ ir.module[0].type[0].structure.field[1].type.array_type.element_count
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Array size must be an integer.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_boolean_type_in_array_size(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " 1 [+9] UInt:8[true] y\n")
- expression = (ir.module[0].type[0].structure.field[1].type.array_type.
- element_count)
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Array size must be an integer.")]
- ], type_check.check_types(ir))
+ def test_error_on_boolean_type_in_array_size(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " 1 [+9] UInt:8[true] y\n"
+ )
+ expression = (
+ ir.module[0].type[0].structure.field[1].type.array_type.element_count
+ )
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Array size must be an integer.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_integer_type_in_existence_condition(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt:8[] x\n"
- " if 1:\n"
- " 1 [+9] UInt:8[] y\n")
- expression = ir.module[0].type[0].structure.field[1].existence_condition
- self.assertEqual([
- [error.error("m.emb", expression.source_location,
- "Existence condition must be a boolean.")]
- ], type_check.check_types(ir))
+ def test_error_on_integer_type_in_existence_condition(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt:8[] x\n"
+ " if 1:\n"
+ " 1 [+9] UInt:8[] y\n"
+ )
+ expression = ir.module[0].type[0].structure.field[1].existence_condition
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ expression.source_location,
+ "Existence condition must be a boolean.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_non_integer_non_enum_parameter(self):
- ir = self._make_ir("struct Foo(f: Flag):\n"
- " 0 [+1] UInt:8[] x\n")
- parameter = ir.module[0].type[0].runtime_parameter[0]
- self.assertEqual(
- [[error.error("m.emb", parameter.physical_type_alias.source_location,
- "Runtime parameters must be integer or enum.")]],
- type_check.check_types(ir))
+ def test_error_on_non_integer_non_enum_parameter(self):
+ ir = self._make_ir("struct Foo(f: Flag):\n" " 0 [+1] UInt:8[] x\n")
+ parameter = ir.module[0].type[0].runtime_parameter[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ parameter.physical_type_alias.source_location,
+ "Runtime parameters must be integer or enum.",
+ )
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_failure_to_pass_parameter(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Bar b\n"
- "struct Bar(f: UInt:6):\n"
- " 0 [+1] UInt:8[] x\n")
- type_ir = ir.module[0].type[0].structure.field[0].type
- bar = ir.module[0].type[1]
- self.assertEqual(
- [[
- error.error("m.emb", type_ir.source_location,
- "Type Bar requires 1 parameter; 0 parameters given."),
- error.note("m.emb", bar.source_location,
- "Definition of type Bar.")
- ]],
- type_check.check_types(ir))
+ def test_error_on_failure_to_pass_parameter(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Bar b\n"
+ "struct Bar(f: UInt:6):\n"
+ " 0 [+1] UInt:8[] x\n"
+ )
+ type_ir = ir.module[0].type[0].structure.field[0].type
+ bar = ir.module[0].type[1]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ type_ir.source_location,
+ "Type Bar requires 1 parameter; 0 parameters given.",
+ ),
+ error.note("m.emb", bar.source_location, "Definition of type Bar."),
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_passing_unneeded_parameter(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Bar(1) b\n"
- "struct Bar:\n"
- " 0 [+1] UInt:8[] x\n")
- type_ir = ir.module[0].type[0].structure.field[0].type
- bar = ir.module[0].type[1]
- self.assertEqual(
- [[
- error.error("m.emb", type_ir.source_location,
- "Type Bar requires 0 parameters; 1 parameter given."),
- error.note("m.emb", bar.source_location,
- "Definition of type Bar.")
- ]],
- type_check.check_types(ir))
+ def test_error_on_passing_unneeded_parameter(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Bar(1) b\n"
+ "struct Bar:\n"
+ " 0 [+1] UInt:8[] x\n"
+ )
+ type_ir = ir.module[0].type[0].structure.field[0].type
+ bar = ir.module[0].type[1]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ type_ir.source_location,
+ "Type Bar requires 0 parameters; 1 parameter given.",
+ ),
+ error.note("m.emb", bar.source_location, "Definition of type Bar."),
+ ]
+ ],
+ type_check.check_types(ir),
+ )
- def test_error_on_passing_wrong_parameter_type(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] Bar(1) b\n"
- "enum Baz:\n"
- " QUX = 1\n"
- "struct Bar(n: Baz):\n"
- " 0 [+1] UInt:8[] x\n")
- type_ir = ir.module[0].type[0].structure.field[0].type
- usage_parameter_ir = type_ir.atomic_type.runtime_parameter[0]
- source_parameter_ir = ir.module[0].type[2].runtime_parameter[0]
- self.assertEqual(
- [[
- error.error("m.emb", usage_parameter_ir.source_location,
- "Parameter 0 of type Bar must be Baz, not integer."),
- error.note("m.emb", source_parameter_ir.source_location,
- "Parameter 0 of Bar.")
- ]],
- type_check.check_types(ir))
+ def test_error_on_passing_wrong_parameter_type(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] Bar(1) b\n"
+ "enum Baz:\n"
+ " QUX = 1\n"
+ "struct Bar(n: Baz):\n"
+ " 0 [+1] UInt:8[] x\n"
+ )
+ type_ir = ir.module[0].type[0].structure.field[0].type
+ usage_parameter_ir = type_ir.atomic_type.runtime_parameter[0]
+ source_parameter_ir = ir.module[0].type[2].runtime_parameter[0]
+ self.assertEqual(
+ [
+ [
+ error.error(
+ "m.emb",
+ usage_parameter_ir.source_location,
+ "Parameter 0 of type Bar must be Baz, not integer.",
+ ),
+ error.note(
+ "m.emb",
+ source_parameter_ir.source_location,
+ "Parameter 0 of Bar.",
+ ),
+ ]
+ ],
+ type_check.check_types(ir),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/front_end/write_inference.py b/compiler/front_end/write_inference.py
index 7cfe7df..8353306 100644
--- a/compiler/front_end/write_inference.py
+++ b/compiler/front_end/write_inference.py
@@ -23,262 +23,269 @@
def _find_field_reference_path(expression):
- """Returns a path to a field reference, or None.
+ """Returns a path to a field reference, or None.
- If the provided expression contains exactly one field_reference,
- _find_field_reference_path will return a list of indexes, such that
- recursively reading the index'th element of expression.function.args will find
- the field_reference. For example, for:
+ If the provided expression contains exactly one field_reference,
+ _find_field_reference_path will return a list of indexes, such that
+ recursively reading the index'th element of expression.function.args will find
+ the field_reference. For example, for:
- 5 + (x * 2)
+ 5 + (x * 2)
- _find_field_reference_path will return [1, 0]: from the top-level `+`
- expression, arg 1 is the `x * 2` expression, and `x` is arg 0 of the `*`
- expression.
+ _find_field_reference_path will return [1, 0]: from the top-level `+`
+ expression, arg 1 is the `x * 2` expression, and `x` is arg 0 of the `*`
+ expression.
- Arguments:
- expression: an ir_data.Expression to walk
+ Arguments:
+ expression: an ir_data.Expression to walk
- Returns:
- A list of indexes to find a field_reference, or None.
- """
- found, indexes = _recursively_find_field_reference_path(expression)
- if found == 1:
- return indexes
- else:
- return None
+ Returns:
+ A list of indexes to find a field_reference, or None.
+ """
+ found, indexes = _recursively_find_field_reference_path(expression)
+ if found == 1:
+ return indexes
+ else:
+ return None
def _recursively_find_field_reference_path(expression):
- """Recursive implementation of _find_field_reference_path."""
- if expression.WhichOneof("expression") == "field_reference":
- return 1, []
- elif expression.WhichOneof("expression") == "function":
- field_count = 0
- path = []
- for index in range(len(expression.function.args)):
- arg = expression.function.args[index]
- arg_result = _recursively_find_field_reference_path(arg)
- arg_field_count, arg_path = arg_result
- if arg_field_count == 1 and field_count == 0:
- path = [index] + arg_path
- field_count += arg_field_count
- if field_count == 1:
- return field_count, path
+ """Recursive implementation of _find_field_reference_path."""
+ if expression.WhichOneof("expression") == "field_reference":
+ return 1, []
+ elif expression.WhichOneof("expression") == "function":
+ field_count = 0
+ path = []
+ for index in range(len(expression.function.args)):
+ arg = expression.function.args[index]
+ arg_result = _recursively_find_field_reference_path(arg)
+ arg_field_count, arg_path = arg_result
+ if arg_field_count == 1 and field_count == 0:
+ path = [index] + arg_path
+ field_count += arg_field_count
+ if field_count == 1:
+ return field_count, path
+ else:
+ return field_count, []
else:
- return field_count, []
- else:
- return 0, []
+ return 0, []
def _invert_expression(expression, ir):
- """For the given expression, searches for an algebraic inverse expression.
+ """For the given expression, searches for an algebraic inverse expression.
- That is, it takes the notional equation:
+ That is, it takes the notional equation:
- $logical_value = expression
+ $logical_value = expression
- and, if there is exactly one `field_reference` in `expression`, it will
- attempt to solve the equation for that field. For example, if the expression
- is `x + 1`, it will iteratively transform:
+ and, if there is exactly one `field_reference` in `expression`, it will
+ attempt to solve the equation for that field. For example, if the expression
+ is `x + 1`, it will iteratively transform:
- $logical_value = x + 1
- $logical_value - 1 = x + 1 - 1
- $logical_value - 1 = x
+ $logical_value = x + 1
+ $logical_value - 1 = x + 1 - 1
+ $logical_value - 1 = x
- and finally return `x` and `$logical_value - 1`.
+ and finally return `x` and `$logical_value - 1`.
- The purpose of this transformation is to find an assignment statement that can
- be used to write back through certain virtual fields. E.g., given:
+ The purpose of this transformation is to find an assignment statement that can
+ be used to write back through certain virtual fields. E.g., given:
- struct Foo:
- 0 [+1] UInt raw_value
- let actual_value = raw_value + 100
+ struct Foo:
+ 0 [+1] UInt raw_value
+ let actual_value = raw_value + 100
- it should be possible to write a value to the `actual_value` field, and have
- it set `raw_value` to the appropriate value.
+ it should be possible to write a value to the `actual_value` field, and have
+ it set `raw_value` to the appropriate value.
- Arguments:
- expression: an ir_data.Expression to be inverted.
- ir: the full IR, for looking up symbols.
+ Arguments:
+ expression: an ir_data.Expression to be inverted.
+ ir: the full IR, for looking up symbols.
- Returns:
- (field_reference, inverse_expression) if expression can be inverted,
- otherwise None.
- """
- reference_path = _find_field_reference_path(expression)
- if reference_path is None:
- return None
- subexpression = expression
- result = ir_data.Expression(
- builtin_reference=ir_data.Reference(
- canonical_name=ir_data.CanonicalName(
- module_file="",
- object_path=["$logical_value"]
- ),
- source_name=[ir_data.Word(
- text="$logical_value",
- source_location=ir_data.Location(is_synthetic=True)
- )],
- source_location=ir_data.Location(is_synthetic=True)
- ),
- type=expression.type,
- source_location=ir_data.Location(is_synthetic=True)
- )
-
- # This loop essentially starts with:
- #
- # f(g(x)) == $logical_value
- #
- # and ends with
- #
- # x == g_inv(f_inv($logical_value))
- #
- # At each step, `subexpression` has one layer removed, and `result` has a
- # corresponding inverse function applied. So, for example, it might start
- # with:
- #
- # 2 + ((3 - x) - 10) == $logical_value
- #
- # On each iteration, `subexpression` and `result` will become:
- #
- # (3 - x) - 10 == $logical_value - 2 [subtract 2 from both sides]
- # (3 - x) == ($logical_value - 2) + 10 [add 10 to both sides]
- # x == 3 - (($logical_value - 2) + 10) [subtract both sides from 3]
- #
- # This is an extremely limited algebraic solver, but it covers common-enough
- # cases.
- #
- # Note that any equation that can be solved here becomes part of Emboss's
- # contract, forever, so be conservative in expanding its solving capabilities!
- for index in reference_path:
- if subexpression.function.function == ir_data.FunctionMapping.ADDITION:
- result = ir_data.Expression(
- function=ir_data.Function(
- function=ir_data.FunctionMapping.SUBTRACTION,
- args=[
- result,
- subexpression.function.args[1 - index],
- ]
- ),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType())
- )
- elif subexpression.function.function == ir_data.FunctionMapping.SUBTRACTION:
- if index == 0:
- result = ir_data.Expression(
- function=ir_data.Function(
- function=ir_data.FunctionMapping.ADDITION,
- args=[
- result,
- subexpression.function.args[1],
- ]
+ Returns:
+ (field_reference, inverse_expression) if expression can be inverted,
+ otherwise None.
+ """
+ reference_path = _find_field_reference_path(expression)
+ if reference_path is None:
+ return None
+ subexpression = expression
+ result = ir_data.Expression(
+ builtin_reference=ir_data.Reference(
+ canonical_name=ir_data.CanonicalName(
+ module_file="", object_path=["$logical_value"]
),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType())
- )
- else:
- result = ir_data.Expression(
- function=ir_data.Function(
- function=ir_data.FunctionMapping.SUBTRACTION,
- args=[
- subexpression.function.args[0],
- result,
- ]
- ),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType())
- )
- else:
- return None
- subexpression = subexpression.function.args[index]
- expression_bounds.compute_constraints_of_expression(result, ir)
- return subexpression, result
+ source_name=[
+ ir_data.Word(
+ text="$logical_value",
+ source_location=ir_data.Location(is_synthetic=True),
+ )
+ ],
+ source_location=ir_data.Location(is_synthetic=True),
+ ),
+ type=expression.type,
+ source_location=ir_data.Location(is_synthetic=True),
+ )
+
+ # This loop essentially starts with:
+ #
+ # f(g(x)) == $logical_value
+ #
+ # and ends with
+ #
+ # x == g_inv(f_inv($logical_value))
+ #
+ # At each step, `subexpression` has one layer removed, and `result` has a
+ # corresponding inverse function applied. So, for example, it might start
+ # with:
+ #
+ # 2 + ((3 - x) - 10) == $logical_value
+ #
+ # On each iteration, `subexpression` and `result` will become:
+ #
+ # (3 - x) - 10 == $logical_value - 2 [subtract 2 from both sides]
+ # (3 - x) == ($logical_value - 2) + 10 [add 10 to both sides]
+ # x == 3 - (($logical_value - 2) + 10) [subtract both sides from 3]
+ #
+ # This is an extremely limited algebraic solver, but it covers common-enough
+ # cases.
+ #
+ # Note that any equation that can be solved here becomes part of Emboss's
+ # contract, forever, so be conservative in expanding its solving capabilities!
+ for index in reference_path:
+ if subexpression.function.function == ir_data.FunctionMapping.ADDITION:
+ result = ir_data.Expression(
+ function=ir_data.Function(
+ function=ir_data.FunctionMapping.SUBTRACTION,
+ args=[
+ result,
+ subexpression.function.args[1 - index],
+ ],
+ ),
+ type=ir_data.ExpressionType(integer=ir_data.IntegerType()),
+ )
+ elif subexpression.function.function == ir_data.FunctionMapping.SUBTRACTION:
+ if index == 0:
+ result = ir_data.Expression(
+ function=ir_data.Function(
+ function=ir_data.FunctionMapping.ADDITION,
+ args=[
+ result,
+ subexpression.function.args[1],
+ ],
+ ),
+ type=ir_data.ExpressionType(integer=ir_data.IntegerType()),
+ )
+ else:
+ result = ir_data.Expression(
+ function=ir_data.Function(
+ function=ir_data.FunctionMapping.SUBTRACTION,
+ args=[
+ subexpression.function.args[0],
+ result,
+ ],
+ ),
+ type=ir_data.ExpressionType(integer=ir_data.IntegerType()),
+ )
+ else:
+ return None
+ subexpression = subexpression.function.args[index]
+ expression_bounds.compute_constraints_of_expression(result, ir)
+ return subexpression, result
def _add_write_method(field, ir):
- """Adds an appropriate write_method to field, if applicable.
+ """Adds an appropriate write_method to field, if applicable.
- Currently, the "alias" write_method will be added for virtual fields of the
- form `let v = some_field_reference` when `some_field_reference` is a physical
- field or a writeable alias. The "physical" write_method will be added for
- physical fields. The "transform" write_method will be added when the virtual
- field's value is an easily-invertible function of a single writeable field.
- All other fields will have the "read_only" write_method; i.e., they will not
- be writeable.
+ Currently, the "alias" write_method will be added for virtual fields of the
+ form `let v = some_field_reference` when `some_field_reference` is a physical
+ field or a writeable alias. The "physical" write_method will be added for
+ physical fields. The "transform" write_method will be added when the virtual
+ field's value is an easily-invertible function of a single writeable field.
+ All other fields will have the "read_only" write_method; i.e., they will not
+ be writeable.
- Arguments:
- field: an ir_data.Field to which to add a write_method.
- ir: The IR in which to look up field_references.
+ Arguments:
+ field: an ir_data.Field to which to add a write_method.
+ ir: The IR in which to look up field_references.
- Returns:
- None
- """
- if field.HasField("write_method"):
- # Do not recompute anything.
- return
+ Returns:
+ None
+ """
+ if field.HasField("write_method"):
+ # Do not recompute anything.
+ return
- if not ir_util.field_is_virtual(field):
- # If the field is not virtual, writes are physical.
- ir_data_utils.builder(field).write_method.physical = True
- return
+ if not ir_util.field_is_virtual(field):
+ # If the field is not virtual, writes are physical.
+ ir_data_utils.builder(field).write_method.physical = True
+ return
- field_checker = ir_data_utils.reader(field)
- field_builder = ir_data_utils.builder(field)
+ field_checker = ir_data_utils.reader(field)
+ field_builder = ir_data_utils.builder(field)
- # A virtual field cannot be a direct alias if it has an additional
- # requirement.
- requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
- if (field_checker.read_transform.WhichOneof("expression") != "field_reference" or
- requires_attr is not None):
- inverse = _invert_expression(field.read_transform, ir)
- if inverse:
- field_reference, function_body = inverse
- referenced_field = ir_util.find_object(
- field_reference.field_reference.path[-1], ir)
- if not isinstance(referenced_field, ir_data.Field):
- reference_is_read_only = True
- else:
- _add_write_method(referenced_field, ir)
- reference_is_read_only = referenced_field.write_method.read_only
- if not reference_is_read_only:
- field_builder.write_method.transform.destination.CopyFrom(
- field_reference.field_reference)
- field_builder.write_method.transform.function_body.CopyFrom(function_body)
- else:
- # If the virtual field's expression is invertible, but its target field
- # is read-only, it is also read-only.
+ # A virtual field cannot be a direct alias if it has an additional
+ # requirement.
+ requires_attr = ir_util.get_attribute(field.attribute, attributes.REQUIRES)
+ if (
+ field_checker.read_transform.WhichOneof("expression") != "field_reference"
+ or requires_attr is not None
+ ):
+ inverse = _invert_expression(field.read_transform, ir)
+ if inverse:
+ field_reference, function_body = inverse
+ referenced_field = ir_util.find_object(
+ field_reference.field_reference.path[-1], ir
+ )
+ if not isinstance(referenced_field, ir_data.Field):
+ reference_is_read_only = True
+ else:
+ _add_write_method(referenced_field, ir)
+ reference_is_read_only = referenced_field.write_method.read_only
+ if not reference_is_read_only:
+ field_builder.write_method.transform.destination.CopyFrom(
+ field_reference.field_reference
+ )
+ field_builder.write_method.transform.function_body.CopyFrom(
+ function_body
+ )
+ else:
+ # If the virtual field's expression is invertible, but its target field
+ # is read-only, it is also read-only.
+ field_builder.write_method.read_only = True
+ else:
+ # If the virtual field's expression is not invertible, it is
+ # read-only.
+ field_builder.write_method.read_only = True
+ return
+
+ referenced_field = ir_util.find_object(
+ field.read_transform.field_reference.path[-1], ir
+ )
+ if not isinstance(referenced_field, ir_data.Field):
+ # If the virtual field aliases a non-field (i.e., a parameter), it is
+ # read-only.
field_builder.write_method.read_only = True
- else:
- # If the virtual field's expression is not invertible, it is
- # read-only.
- field_builder.write_method.read_only = True
- return
+ return
- referenced_field = ir_util.find_object(
- field.read_transform.field_reference.path[-1], ir)
- if not isinstance(referenced_field, ir_data.Field):
- # If the virtual field aliases a non-field (i.e., a parameter), it is
- # read-only.
- field_builder.write_method.read_only = True
- return
+ _add_write_method(referenced_field, ir)
+ if referenced_field.write_method.read_only:
+ # If the virtual field directly aliases a read-only field, it is read-only.
+ field_builder.write_method.read_only = True
+ return
- _add_write_method(referenced_field, ir)
- if referenced_field.write_method.read_only:
- # If the virtual field directly aliases a read-only field, it is read-only.
- field_builder.write_method.read_only = True
- return
-
- # Otherwise, it can be written as a direct alias.
- field_builder.write_method.alias.CopyFrom(
- field.read_transform.field_reference)
+ # Otherwise, it can be written as a direct alias.
+ field_builder.write_method.alias.CopyFrom(field.read_transform.field_reference)
def set_write_methods(ir):
- """Sets the write_method member of all ir_data.Fields in ir.
+ """Sets the write_method member of all ir_data.Fields in ir.
- Arguments:
- ir: The IR to which to add write_methods.
+ Arguments:
+ ir: The IR to which to add write_methods.
- Returns:
- A list of errors, or an empty list.
- """
- traverse_ir.fast_traverse_ir_top_down(ir, [ir_data.Field], _add_write_method)
- return []
+ Returns:
+ A list of errors, or an empty list.
+ """
+ traverse_ir.fast_traverse_ir_top_down(ir, [ir_data.Field], _add_write_method)
+ return []
diff --git a/compiler/front_end/write_inference_test.py b/compiler/front_end/write_inference_test.py
index d1de5f2..c6afa2f 100644
--- a/compiler/front_end/write_inference_test.py
+++ b/compiler/front_end/write_inference_test.py
@@ -23,194 +23,193 @@
class WriteInferenceTest(unittest.TestCase):
- def _make_ir(self, emb_text):
- ir, unused_debug_info, errors = glue.parse_emboss_file(
- "m.emb",
- test_util.dict_file_reader({"m.emb": emb_text}),
- stop_before_step="set_write_methods")
- assert not errors, errors
- return ir
+ def _make_ir(self, emb_text):
+ ir, unused_debug_info, errors = glue.parse_emboss_file(
+ "m.emb",
+ test_util.dict_file_reader({"m.emb": emb_text}),
+ stop_before_step="set_write_methods",
+ )
+ assert not errors, errors
+ return ir
- def test_adds_physical_write_method(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- self.assertTrue(
- ir.module[0].type[0].structure.field[0].write_method.physical)
+ def test_adds_physical_write_method(self):
+ ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ self.assertTrue(ir.module[0].type[0].structure.field[0].write_method.physical)
- def test_adds_read_only_write_method_to_non_alias_virtual(self):
- ir = self._make_ir("struct Foo:\n"
- " let x = 5\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- self.assertTrue(
- ir.module[0].type[0].structure.field[0].write_method.read_only)
+ def test_adds_read_only_write_method_to_non_alias_virtual(self):
+ ir = self._make_ir("struct Foo:\n" " let x = 5\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ self.assertTrue(ir.module[0].type[0].structure.field[0].write_method.read_only)
- def test_adds_alias_write_method_to_alias_of_physical_field(self):
- ir = self._make_ir("struct Foo:\n"
- " let x = y\n"
- " 0 [+1] UInt y\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[0]
- self.assertTrue(field.write_method.HasField("alias"))
- self.assertEqual(
- "y", field.write_method.alias.path[0].canonical_name.object_path[-1])
+ def test_adds_alias_write_method_to_alias_of_physical_field(self):
+ ir = self._make_ir("struct Foo:\n" " let x = y\n" " 0 [+1] UInt y\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertTrue(field.write_method.HasField("alias"))
+ self.assertEqual(
+ "y", field.write_method.alias.path[0].canonical_name.object_path[-1]
+ )
- def test_adds_alias_write_method_to_alias_of_alias_of_physical_field(self):
- ir = self._make_ir("struct Foo:\n"
- " let x = z\n"
- " let z = y\n"
- " 0 [+1] UInt y\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[0]
- self.assertTrue(field.write_method.HasField("alias"))
- self.assertEqual(
- "z", field.write_method.alias.path[0].canonical_name.object_path[-1])
+ def test_adds_alias_write_method_to_alias_of_alias_of_physical_field(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " let x = z\n" " let z = y\n" " 0 [+1] UInt y\n"
+ )
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertTrue(field.write_method.HasField("alias"))
+ self.assertEqual(
+ "z", field.write_method.alias.path[0].canonical_name.object_path[-1]
+ )
- def test_adds_read_only_write_method_to_alias_of_read_only(self):
- ir = self._make_ir("struct Foo:\n"
- " let x = y\n"
- " let y = 5\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[0]
- self.assertTrue(field.write_method.read_only)
+ def test_adds_read_only_write_method_to_alias_of_read_only(self):
+ ir = self._make_ir("struct Foo:\n" " let x = y\n" " let y = 5\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertTrue(field.write_method.read_only)
- def test_adds_read_only_write_method_to_alias_of_alias_of_read_only(self):
- ir = self._make_ir("struct Foo:\n"
- " let x = z\n"
- " let z = y\n"
- " let y = 5\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[0]
- self.assertTrue(field.write_method.read_only)
+ def test_adds_read_only_write_method_to_alias_of_alias_of_read_only(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " let x = z\n" " let z = y\n" " let y = 5\n"
+ )
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertTrue(field.write_method.read_only)
- def test_adds_read_only_write_method_to_alias_of_parameter(self):
- ir = self._make_ir("struct Foo(x: UInt:8):\n"
- " let y = x\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[0]
- self.assertTrue(field.write_method.read_only)
+ def test_adds_read_only_write_method_to_alias_of_parameter(self):
+ ir = self._make_ir("struct Foo(x: UInt:8):\n" " let y = x\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertTrue(field.write_method.read_only)
- def test_adds_transform_write_method_to_base_value_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = x + 50\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[1]
- transform = field.write_method.transform
- self.assertTrue(transform)
- self.assertEqual(
- "x",
- transform.destination.path[0].canonical_name.object_path[-1])
- self.assertEqual(ir_data.FunctionMapping.SUBTRACTION,
- transform.function_body.function.function)
- arg0, arg1 = transform.function_body.function.args
- self.assertEqual("$logical_value",
- arg0.builtin_reference.canonical_name.object_path[0])
- self.assertEqual("50", arg1.constant.value)
+ def test_adds_transform_write_method_to_base_value_field(self):
+ ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = x + 50\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[1]
+ transform = field.write_method.transform
+ self.assertTrue(transform)
+ self.assertEqual(
+ "x", transform.destination.path[0].canonical_name.object_path[-1]
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.SUBTRACTION,
+ transform.function_body.function.function,
+ )
+ arg0, arg1 = transform.function_body.function.args
+ self.assertEqual(
+ "$logical_value", arg0.builtin_reference.canonical_name.object_path[0]
+ )
+ self.assertEqual("50", arg1.constant.value)
- def test_adds_transform_write_method_to_negative_base_value_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = x - 50\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[1]
- transform = field.write_method.transform
- self.assertTrue(transform)
- self.assertEqual(
- "x",
- transform.destination.path[0].canonical_name.object_path[-1])
- self.assertEqual(ir_data.FunctionMapping.ADDITION,
- transform.function_body.function.function)
- arg0, arg1 = transform.function_body.function.args
- self.assertEqual("$logical_value",
- arg0.builtin_reference.canonical_name.object_path[0])
- self.assertEqual("50", arg1.constant.value)
+ def test_adds_transform_write_method_to_negative_base_value_field(self):
+ ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = x - 50\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[1]
+ transform = field.write_method.transform
+ self.assertTrue(transform)
+ self.assertEqual(
+ "x", transform.destination.path[0].canonical_name.object_path[-1]
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.ADDITION, transform.function_body.function.function
+ )
+ arg0, arg1 = transform.function_body.function.args
+ self.assertEqual(
+ "$logical_value", arg0.builtin_reference.canonical_name.object_path[0]
+ )
+ self.assertEqual("50", arg1.constant.value)
- def test_adds_transform_write_method_to_reversed_base_value_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = 50 + x\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[1]
- transform = field.write_method.transform
- self.assertTrue(transform)
- self.assertEqual(
- "x",
- transform.destination.path[0].canonical_name.object_path[-1])
- self.assertEqual(ir_data.FunctionMapping.SUBTRACTION,
- transform.function_body.function.function)
- arg0, arg1 = transform.function_body.function.args
- self.assertEqual("$logical_value",
- arg0.builtin_reference.canonical_name.object_path[0])
- self.assertEqual("50", arg1.constant.value)
+ def test_adds_transform_write_method_to_reversed_base_value_field(self):
+ ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = 50 + x\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[1]
+ transform = field.write_method.transform
+ self.assertTrue(transform)
+ self.assertEqual(
+ "x", transform.destination.path[0].canonical_name.object_path[-1]
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.SUBTRACTION,
+ transform.function_body.function.function,
+ )
+ arg0, arg1 = transform.function_body.function.args
+ self.assertEqual(
+ "$logical_value", arg0.builtin_reference.canonical_name.object_path[0]
+ )
+ self.assertEqual("50", arg1.constant.value)
- def test_adds_transform_write_method_to_reversed_negative_base_value_field(
- self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = 50 - x\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[1]
- transform = field.write_method.transform
- self.assertTrue(transform)
- self.assertEqual(
- "x",
- transform.destination.path[0].canonical_name.object_path[-1])
- self.assertEqual(ir_data.FunctionMapping.SUBTRACTION,
- transform.function_body.function.function)
- arg0, arg1 = transform.function_body.function.args
- self.assertEqual("50", arg0.constant.value)
- self.assertEqual("$logical_value",
- arg1.builtin_reference.canonical_name.object_path[0])
+ def test_adds_transform_write_method_to_reversed_negative_base_value_field(self):
+ ir = self._make_ir("struct Foo:\n" " 0 [+1] UInt x\n" " let y = 50 - x\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[1]
+ transform = field.write_method.transform
+ self.assertTrue(transform)
+ self.assertEqual(
+ "x", transform.destination.path[0].canonical_name.object_path[-1]
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.SUBTRACTION,
+ transform.function_body.function.function,
+ )
+ arg0, arg1 = transform.function_body.function.args
+ self.assertEqual("50", arg0.constant.value)
+ self.assertEqual(
+ "$logical_value", arg1.builtin_reference.canonical_name.object_path[0]
+ )
- def test_adds_transform_write_method_to_nested_invertible_field(self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = 30 + (50 - x)\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[1]
- transform = field.write_method.transform
- self.assertTrue(transform)
- self.assertEqual(
- "x",
- transform.destination.path[0].canonical_name.object_path[-1])
- self.assertEqual(ir_data.FunctionMapping.SUBTRACTION,
- transform.function_body.function.function)
- arg0, arg1 = transform.function_body.function.args
- self.assertEqual("50", arg0.constant.value)
- self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, arg1.function.function)
- arg10, arg11 = arg1.function.args
- self.assertEqual("$logical_value",
- arg10.builtin_reference.canonical_name.object_path[0])
- self.assertEqual("30", arg11.constant.value)
+ def test_adds_transform_write_method_to_nested_invertible_field(self):
+ ir = self._make_ir(
+ "struct Foo:\n" " 0 [+1] UInt x\n" " let y = 30 + (50 - x)\n"
+ )
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[1]
+ transform = field.write_method.transform
+ self.assertTrue(transform)
+ self.assertEqual(
+ "x", transform.destination.path[0].canonical_name.object_path[-1]
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.SUBTRACTION,
+ transform.function_body.function.function,
+ )
+ arg0, arg1 = transform.function_body.function.args
+ self.assertEqual("50", arg0.constant.value)
+ self.assertEqual(ir_data.FunctionMapping.SUBTRACTION, arg1.function.function)
+ arg10, arg11 = arg1.function.args
+ self.assertEqual(
+ "$logical_value", arg10.builtin_reference.canonical_name.object_path[0]
+ )
+ self.assertEqual("30", arg11.constant.value)
- def test_does_not_add_transform_write_method_for_parameter_target(self):
- ir = self._make_ir("struct Foo(x: UInt:8):\n"
- " let y = 50 + x\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[0]
- self.assertEqual("read_only", field.write_method.WhichOneof("method"))
+ def test_does_not_add_transform_write_method_for_parameter_target(self):
+ ir = self._make_ir("struct Foo(x: UInt:8):\n" " let y = 50 + x\n")
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[0]
+ self.assertEqual("read_only", field.write_method.WhichOneof("method"))
- def test_adds_transform_write_method_with_complex_auxiliary_subexpression(
- self):
- ir = self._make_ir("struct Foo:\n"
- " 0 [+1] UInt x\n"
- " let y = x - $max(Foo.$size_in_bytes, Foo.z)\n"
- " let z = 500\n")
- self.assertEqual([], write_inference.set_write_methods(ir))
- field = ir.module[0].type[0].structure.field[1]
- transform = field.write_method.transform
- self.assertTrue(transform)
- self.assertEqual(
- "x",
- transform.destination.path[0].canonical_name.object_path[-1])
- self.assertEqual(ir_data.FunctionMapping.ADDITION,
- transform.function_body.function.function)
- args = transform.function_body.function.args
- self.assertEqual("$logical_value",
- args[0].builtin_reference.canonical_name.object_path[0])
- self.assertEqual(field.read_transform.function.args[1], args[1])
+ def test_adds_transform_write_method_with_complex_auxiliary_subexpression(self):
+ ir = self._make_ir(
+ "struct Foo:\n"
+ " 0 [+1] UInt x\n"
+ " let y = x - $max(Foo.$size_in_bytes, Foo.z)\n"
+ " let z = 500\n"
+ )
+ self.assertEqual([], write_inference.set_write_methods(ir))
+ field = ir.module[0].type[0].structure.field[1]
+ transform = field.write_method.transform
+ self.assertTrue(transform)
+ self.assertEqual(
+ "x", transform.destination.path[0].canonical_name.object_path[-1]
+ )
+ self.assertEqual(
+ ir_data.FunctionMapping.ADDITION, transform.function_body.function.function
+ )
+ args = transform.function_body.function.args
+ self.assertEqual(
+ "$logical_value", args[0].builtin_reference.canonical_name.object_path[0]
+ )
+ self.assertEqual(field.read_transform.function.args[1], args[1])
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py
index 6e04280..a83cf51 100644
--- a/compiler/util/attribute_util.py
+++ b/compiler/util/attribute_util.py
@@ -31,61 +31,94 @@
def _attribute_name_for_errors(attr):
- if ir_data_utils.reader(attr).back_end.text:
- return f"({attr.back_end.text}) {attr.name.text}"
- else:
- return attr.name.text
+ if ir_data_utils.reader(attr).back_end.text:
+ return f"({attr.back_end.text}) {attr.name.text}"
+ else:
+ return attr.name.text
# Attribute type checkers
def _is_constant_boolean(attr, module_source_file):
- """Checks if the given attr is a constant boolean."""
- if not attr.value.expression.type.boolean.HasField("value"):
- return [[error.error(module_source_file,
- attr.value.source_location,
- _BAD_TYPE_MESSAGE.format(
- name=_attribute_name_for_errors(attr),
- type="a constant boolean"))]]
- return []
+ """Checks if the given attr is a constant boolean."""
+ if not attr.value.expression.type.boolean.HasField("value"):
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ _BAD_TYPE_MESSAGE.format(
+ name=_attribute_name_for_errors(attr), type="a constant boolean"
+ ),
+ )
+ ]
+ ]
+ return []
def _is_boolean(attr, module_source_file):
- """Checks if the given attr is a boolean."""
- if attr.value.expression.type.WhichOneof("type") != "boolean":
- return [[error.error(module_source_file,
- attr.value.source_location,
- _BAD_TYPE_MESSAGE.format(
- name=_attribute_name_for_errors(attr),
- type="a boolean"))]]
- return []
+ """Checks if the given attr is a boolean."""
+ if attr.value.expression.type.WhichOneof("type") != "boolean":
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ _BAD_TYPE_MESSAGE.format(
+ name=_attribute_name_for_errors(attr), type="a boolean"
+ ),
+ )
+ ]
+ ]
+ return []
def _is_constant_integer(attr, module_source_file):
- """Checks if the given attr is an integer constant expression."""
- if (not attr.value.HasField("expression") or
- attr.value.expression.type.WhichOneof("type") != "integer"):
- return [[error.error(module_source_file,
- attr.value.source_location,
- _BAD_TYPE_MESSAGE.format(
- name=_attribute_name_for_errors(attr),
- type="an integer"))]]
- if not ir_util.is_constant(attr.value.expression):
- return [[error.error(module_source_file,
- attr.value.source_location,
- _MUST_BE_CONSTANT_MESSAGE.format(
- name=_attribute_name_for_errors(attr)))]]
- return []
+ """Checks if the given attr is an integer constant expression."""
+ if (
+ not attr.value.HasField("expression")
+ or attr.value.expression.type.WhichOneof("type") != "integer"
+ ):
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ _BAD_TYPE_MESSAGE.format(
+ name=_attribute_name_for_errors(attr), type="an integer"
+ ),
+ )
+ ]
+ ]
+ if not ir_util.is_constant(attr.value.expression):
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ _MUST_BE_CONSTANT_MESSAGE.format(
+ name=_attribute_name_for_errors(attr)
+ ),
+ )
+ ]
+ ]
+ return []
def _is_string(attr, module_source_file):
- """Checks if the given attr is a string."""
- if not attr.value.HasField("string_constant"):
- return [[error.error(module_source_file,
- attr.value.source_location,
- _BAD_TYPE_MESSAGE.format(
- name=_attribute_name_for_errors(attr),
- type="a string"))]]
- return []
+ """Checks if the given attr is a string."""
+ if not attr.value.HasField("string_constant"):
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ _BAD_TYPE_MESSAGE.format(
+ name=_attribute_name_for_errors(attr), type="a string"
+ ),
+ )
+ ]
+ ]
+ return []
# Provide more readable names for these functions when used in attribute type
@@ -97,215 +130,287 @@
def string_from_list(valid_values):
- """Checks if the given attr has one of the valid_values."""
- def _string_from_list(attr, module_source_file):
- if ir_data_utils.reader(attr).value.string_constant.text not in valid_values:
- return [[error.error(module_source_file,
- attr.value.source_location,
- "Attribute '{name}' must be '{options}'.".format(
- name=_attribute_name_for_errors(attr),
- options="' or '".join(sorted(valid_values))))]]
- return []
- return _string_from_list
+ """Checks if the given attr has one of the valid_values."""
+
+ def _string_from_list(attr, module_source_file):
+ if ir_data_utils.reader(attr).value.string_constant.text not in valid_values:
+ return [
+ [
+ error.error(
+ module_source_file,
+ attr.value.source_location,
+ "Attribute '{name}' must be '{options}'.".format(
+ name=_attribute_name_for_errors(attr),
+ options="' or '".join(sorted(valid_values)),
+ ),
+ )
+ ]
+ ]
+ return []
+
+ return _string_from_list
-def check_attributes_in_ir(ir,
- *,
- back_end=None,
- types=None,
- module_attributes=None,
- struct_attributes=None,
- bits_attributes=None,
- enum_attributes=None,
- enum_value_attributes=None,
- external_attributes=None,
- structure_virtual_field_attributes=None,
- structure_physical_field_attributes=None):
- """Performs basic checks on all attributes in the given ir.
+def check_attributes_in_ir(
+ ir,
+ *,
+ back_end=None,
+ types=None,
+ module_attributes=None,
+ struct_attributes=None,
+ bits_attributes=None,
+ enum_attributes=None,
+ enum_value_attributes=None,
+ external_attributes=None,
+ structure_virtual_field_attributes=None,
+ structure_physical_field_attributes=None,
+):
+ """Performs basic checks on all attributes in the given ir.
- This function calls _check_attributes on each attribute list in ir.
+ This function calls _check_attributes on each attribute list in ir.
- Arguments:
- ir: An ir_data.EmbossIr to check.
- back_end: A string specifying the attribute qualifier to check (such as
- `cpp` for `[(cpp) namespace = "foo"]`), or None to check unqualified
- attributes.
+ Arguments:
+ ir: An ir_data.EmbossIr to check.
+ back_end: A string specifying the attribute qualifier to check (such as
+ `cpp` for `[(cpp) namespace = "foo"]`), or None to check unqualified
+ attributes.
- Attributes with a different qualifier will not be checked.
- types: A map from attribute names to validators, such as:
- {
- "maximum_bits": attribute_util.INTEGER_CONSTANT,
- "requires": attribute_util.BOOLEAN,
- }
- module_attributes: A set of (attribute_name, is_default) tuples specifying
- the attributes that are allowed at module scope.
- struct_attributes: A set of (attribute_name, is_default) tuples specifying
- the attributes that are allowed at `struct` scope.
- bits_attributes: A set of (attribute_name, is_default) tuples specifying
- the attributes that are allowed at `bits` scope.
- enum_attributes: A set of (attribute_name, is_default) tuples specifying
- the attributes that are allowed at `enum` scope.
- enum_value_attributes: A set of (attribute_name, is_default) tuples
- specifying the attributes that are allowed at the scope of enum values.
- external_attributes: A set of (attribute_name, is_default) tuples
- specifying the attributes that are allowed at `external` scope.
- structure_virtual_field_attributes: A set of (attribute_name, is_default)
- tuples specifying the attributes that are allowed at the scope of
- virtual fields (`let` fields) in structures (both `struct` and `bits`).
- structure_physical_field_attributes: A set of (attribute_name, is_default)
- tuples specifying the attributes that are allowed at the scope of
- physical fields in structures (both `struct` and `bits`).
+ Attributes with a different qualifier will not be checked.
+ types: A map from attribute names to validators, such as:
+ {
+ "maximum_bits": attribute_util.INTEGER_CONSTANT,
+ "requires": attribute_util.BOOLEAN,
+ }
+ module_attributes: A set of (attribute_name, is_default) tuples specifying
+ the attributes that are allowed at module scope.
+ struct_attributes: A set of (attribute_name, is_default) tuples specifying
+ the attributes that are allowed at `struct` scope.
+ bits_attributes: A set of (attribute_name, is_default) tuples specifying
+ the attributes that are allowed at `bits` scope.
+ enum_attributes: A set of (attribute_name, is_default) tuples specifying
+ the attributes that are allowed at `enum` scope.
+ enum_value_attributes: A set of (attribute_name, is_default) tuples
+ specifying the attributes that are allowed at the scope of enum values.
+ external_attributes: A set of (attribute_name, is_default) tuples
+ specifying the attributes that are allowed at `external` scope.
+ structure_virtual_field_attributes: A set of (attribute_name, is_default)
+ tuples specifying the attributes that are allowed at the scope of
+ virtual fields (`let` fields) in structures (both `struct` and `bits`).
+ structure_physical_field_attributes: A set of (attribute_name, is_default)
+ tuples specifying the attributes that are allowed at the scope of
+ physical fields in structures (both `struct` and `bits`).
- Returns:
- A list of lists of error.error, or an empty list if there were no errors.
- """
+ Returns:
+ A list of lists of error.error, or an empty list if there were no errors.
+ """
- def check_module(module, errors):
- errors.extend(_check_attributes(
- module.attribute, types, back_end, module_attributes,
- "module '{}'".format(
- module.source_file_name), module.source_file_name))
+ def check_module(module, errors):
+ errors.extend(
+ _check_attributes(
+ module.attribute,
+ types,
+ back_end,
+ module_attributes,
+ "module '{}'".format(module.source_file_name),
+ module.source_file_name,
+ )
+ )
- def check_type_definition(type_definition, source_file_name, errors):
- if type_definition.HasField("structure"):
- if type_definition.addressable_unit == ir_data.AddressableUnit.BYTE:
- errors.extend(_check_attributes(
- type_definition.attribute, types, back_end, struct_attributes,
- "struct '{}'".format(
- type_definition.name.name.text), source_file_name))
- elif type_definition.addressable_unit == ir_data.AddressableUnit.BIT:
- errors.extend(_check_attributes(
- type_definition.attribute, types, back_end, bits_attributes,
- "bits '{}'".format(
- type_definition.name.name.text), source_file_name))
- else:
- assert False, "Unexpected addressable_unit '{}'".format(
- type_definition.addressable_unit)
- elif type_definition.HasField("enumeration"):
- errors.extend(_check_attributes(
- type_definition.attribute, types, back_end, enum_attributes,
- "enum '{}'".format(
- type_definition.name.name.text), source_file_name))
- elif type_definition.HasField("external"):
- errors.extend(_check_attributes(
- type_definition.attribute, types, back_end, external_attributes,
- "external '{}'".format(
- type_definition.name.name.text), source_file_name))
+ def check_type_definition(type_definition, source_file_name, errors):
+ if type_definition.HasField("structure"):
+ if type_definition.addressable_unit == ir_data.AddressableUnit.BYTE:
+ errors.extend(
+ _check_attributes(
+ type_definition.attribute,
+ types,
+ back_end,
+ struct_attributes,
+ "struct '{}'".format(type_definition.name.name.text),
+ source_file_name,
+ )
+ )
+ elif type_definition.addressable_unit == ir_data.AddressableUnit.BIT:
+ errors.extend(
+ _check_attributes(
+ type_definition.attribute,
+ types,
+ back_end,
+ bits_attributes,
+ "bits '{}'".format(type_definition.name.name.text),
+ source_file_name,
+ )
+ )
+ else:
+ assert False, "Unexpected addressable_unit '{}'".format(
+ type_definition.addressable_unit
+ )
+ elif type_definition.HasField("enumeration"):
+ errors.extend(
+ _check_attributes(
+ type_definition.attribute,
+ types,
+ back_end,
+ enum_attributes,
+ "enum '{}'".format(type_definition.name.name.text),
+ source_file_name,
+ )
+ )
+ elif type_definition.HasField("external"):
+ errors.extend(
+ _check_attributes(
+ type_definition.attribute,
+ types,
+ back_end,
+ external_attributes,
+ "external '{}'".format(type_definition.name.name.text),
+ source_file_name,
+ )
+ )
- def check_struct_field(field, source_file_name, errors):
- if ir_util.field_is_virtual(field):
- field_attributes = structure_virtual_field_attributes
- field_adjective = "virtual "
- else:
- field_attributes = structure_physical_field_attributes
- field_adjective = ""
- errors.extend(_check_attributes(
- field.attribute, types, back_end, field_attributes,
- "{}struct field '{}'".format(field_adjective, field.name.name.text),
- source_file_name))
+ def check_struct_field(field, source_file_name, errors):
+ if ir_util.field_is_virtual(field):
+ field_attributes = structure_virtual_field_attributes
+ field_adjective = "virtual "
+ else:
+ field_attributes = structure_physical_field_attributes
+ field_adjective = ""
+ errors.extend(
+ _check_attributes(
+ field.attribute,
+ types,
+ back_end,
+ field_attributes,
+ "{}struct field '{}'".format(field_adjective, field.name.name.text),
+ source_file_name,
+ )
+ )
- def check_enum_value(value, source_file_name, errors):
- errors.extend(_check_attributes(
- value.attribute, types, back_end, enum_value_attributes,
- "enum value '{}'".format(value.name.name.text), source_file_name))
+ def check_enum_value(value, source_file_name, errors):
+ errors.extend(
+ _check_attributes(
+ value.attribute,
+ types,
+ back_end,
+ enum_value_attributes,
+ "enum value '{}'".format(value.name.name.text),
+ source_file_name,
+ )
+ )
- errors = []
- # TODO(bolms): Add a check that only known $default'ed attributes are
- # used.
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Module], check_module,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.TypeDefinition], check_type_definition,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.Field], check_struct_field,
- parameters={"errors": errors})
- traverse_ir.fast_traverse_ir_top_down(
- ir, [ir_data.EnumValue], check_enum_value,
- parameters={"errors": errors})
- return errors
+ errors = []
+ # TODO(bolms): Add a check that only known $default'ed attributes are
+ # used.
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Module], check_module, parameters={"errors": errors}
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir,
+ [ir_data.TypeDefinition],
+ check_type_definition,
+ parameters={"errors": errors},
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.Field], check_struct_field, parameters={"errors": errors}
+ )
+ traverse_ir.fast_traverse_ir_top_down(
+ ir, [ir_data.EnumValue], check_enum_value, parameters={"errors": errors}
+ )
+ return errors
-def _check_attributes(attribute_list, types, back_end, attribute_specs,
- context_name, module_source_file):
- """Performs basic checks on the given list of attributes.
+def _check_attributes(
+ attribute_list, types, back_end, attribute_specs, context_name, module_source_file
+):
+ """Performs basic checks on the given list of attributes.
- Checks the given attribute_list for duplicates, unknown attributes, attributes
- with incorrect type, and attributes whose values are not constant.
+ Checks the given attribute_list for duplicates, unknown attributes, attributes
+ with incorrect type, and attributes whose values are not constant.
- Arguments:
- attribute_list: An iterable of ir_data.Attribute.
- back_end: The qualifier for attributes to check, or None.
- attribute_specs: A dict of attribute names to _Attribute structures
- specifying the allowed attributes.
- context_name: A name for the context of these attributes, such as "struct
- 'Foo'" or "module 'm.emb'". Used in error messages.
- module_source_file: The value of module.source_file_name from the module
- containing 'attribute_list'. Used in error messages.
+ Arguments:
+ attribute_list: An iterable of ir_data.Attribute.
+ back_end: The qualifier for attributes to check, or None.
+ attribute_specs: A dict of attribute names to _Attribute structures
+ specifying the allowed attributes.
+ context_name: A name for the context of these attributes, such as "struct
+ 'Foo'" or "module 'm.emb'". Used in error messages.
+ module_source_file: The value of module.source_file_name from the module
+ containing 'attribute_list'. Used in error messages.
- Returns:
- A list of lists of error.Errors. An empty list indicates no errors were
- found.
- """
- if attribute_specs is None:
- attribute_specs = []
- errors = []
- already_seen_attributes = {}
- for attr in attribute_list:
- field_checker = ir_data_utils.reader(attr)
- if field_checker.back_end.text:
- if attr.back_end.text != back_end:
- continue
- else:
- if back_end is not None:
- continue
- attribute_name = _attribute_name_for_errors(attr)
- attr_key = (field_checker.name.text, field_checker.is_default)
- if attr_key in already_seen_attributes:
- original_attr = already_seen_attributes[attr_key]
- errors.append([
- error.error(module_source_file,
- attr.source_location,
- "Duplicate attribute '{}'.".format(attribute_name)),
- error.note(module_source_file,
- original_attr.source_location,
- "Original attribute")])
- continue
- already_seen_attributes[attr_key] = attr
+ Returns:
+ A list of lists of error.Errors. An empty list indicates no errors were
+ found.
+ """
+ if attribute_specs is None:
+ attribute_specs = []
+ errors = []
+ already_seen_attributes = {}
+ for attr in attribute_list:
+ field_checker = ir_data_utils.reader(attr)
+ if field_checker.back_end.text:
+ if attr.back_end.text != back_end:
+ continue
+ else:
+ if back_end is not None:
+ continue
+ attribute_name = _attribute_name_for_errors(attr)
+ attr_key = (field_checker.name.text, field_checker.is_default)
+ if attr_key in already_seen_attributes:
+ original_attr = already_seen_attributes[attr_key]
+ errors.append(
+ [
+ error.error(
+ module_source_file,
+ attr.source_location,
+ "Duplicate attribute '{}'.".format(attribute_name),
+ ),
+ error.note(
+ module_source_file,
+ original_attr.source_location,
+ "Original attribute",
+ ),
+ ]
+ )
+ continue
+ already_seen_attributes[attr_key] = attr
- if attr_key not in attribute_specs:
- if attr.is_default:
- error_message = "Attribute '{}' may not be defaulted on {}.".format(
- attribute_name, context_name)
- else:
- error_message = "Unknown attribute '{}' on {}.".format(attribute_name,
- context_name)
- errors.append([error.error(module_source_file,
- attr.name.source_location,
- error_message)])
- else:
- errors.extend(types[attr.name.text](attr, module_source_file))
- return errors
+ if attr_key not in attribute_specs:
+ if attr.is_default:
+ error_message = "Attribute '{}' may not be defaulted on {}.".format(
+ attribute_name, context_name
+ )
+ else:
+ error_message = "Unknown attribute '{}' on {}.".format(
+ attribute_name, context_name
+ )
+ errors.append(
+ [
+ error.error(
+ module_source_file, attr.name.source_location, error_message
+ )
+ ]
+ )
+ else:
+ errors.extend(types[attr.name.text](attr, module_source_file))
+ return errors
def gather_default_attributes(obj, defaults):
- """Gathers default attributes for an IR object
+ """Gathers default attributes for an IR object
- This is designed to be able to be used as-is as an incidental action in an IR
- traversal to accumulate defaults for child nodes.
+ This is designed to be able to be used as-is as an incidental action in an IR
+ traversal to accumulate defaults for child nodes.
- Arguments:
- defaults: A dict of `{ "defaults": { attr.name.text: attr } }`
+ Arguments:
+ defaults: A dict of `{ "defaults": { attr.name.text: attr } }`
- Returns:
- A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults
- provided by `obj` added/overridden.
- """
- defaults = defaults.copy()
- for attr in obj.attribute:
- if attr.is_default:
- defaulted_attr = ir_data_utils.copy(attr)
- defaulted_attr.is_default = False
- defaults[attr.name.text] = defaulted_attr
- return {"defaults": defaults}
+ Returns:
+ A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults
+ provided by `obj` added/overridden.
+ """
+ defaults = defaults.copy()
+ for attr in obj.attribute:
+ if attr.is_default:
+ defaulted_attr = ir_data_utils.copy(attr)
+ defaulted_attr.is_default = False
+ defaults[attr.name.text] = defaulted_attr
+ return {"defaults": defaults}
diff --git a/compiler/util/error.py b/compiler/util/error.py
index e408c71..a22fa4a 100644
--- a/compiler/util/error.py
+++ b/compiler/util/error.py
@@ -66,186 +66,199 @@
BOLD = "\033[0;1m"
RESET = "\033[0m"
+
def _copy(location):
- location = ir_data_utils.copy(location)
- if not location:
- location = parser_types.make_location((0,0), (0,0))
- return location
+ location = ir_data_utils.copy(location)
+ if not location:
+ location = parser_types.make_location((0, 0), (0, 0))
+ return location
def error(source_file, location, message):
- """Returns an object representing an error message."""
- return _Message(source_file, _copy(location), ERROR, message)
+ """Returns an object representing an error message."""
+ return _Message(source_file, _copy(location), ERROR, message)
def warn(source_file, location, message):
- """Returns an object representing a warning."""
- return _Message(source_file, _copy(location), WARNING, message)
+ """Returns an object representing a warning."""
+ return _Message(source_file, _copy(location), WARNING, message)
def note(source_file, location, message):
- """Returns and object representing an informational note."""
- return _Message(source_file, _copy(location), NOTE, message)
+ """Returns and object representing an informational note."""
+ return _Message(source_file, _copy(location), NOTE, message)
class _Message(object):
- """_Message holds a human-readable message."""
- __slots__ = ("location", "source_file", "severity", "message")
+ """_Message holds a human-readable message."""
- def __init__(self, source_file, location, severity, message):
- self.location = location
- self.source_file = source_file
- self.severity = severity
- self.message = message
+ __slots__ = ("location", "source_file", "severity", "message")
- def format(self, source_code):
- """Formats the _Message for display.
+ def __init__(self, source_file, location, severity, message):
+ self.location = location
+ self.source_file = source_file
+ self.severity = severity
+ self.message = message
- Arguments:
- source_code: A dict of file names to source texts. This is used to
- render source snippets.
+ def format(self, source_code):
+ """Formats the _Message for display.
- Returns:
- A list of tuples.
+ Arguments:
+ source_code: A dict of file names to source texts. This is used to
+ render source snippets.
- The first element of each tuple is an escape sequence used to put a Unix
- terminal into a particular color mode. For use in non-Unix-terminal
- output, the string will match one of the color names exported by this
- module.
+ Returns:
+ A list of tuples.
- The second element is a string containing text to show to the user.
+ The first element of each tuple is an escape sequence used to put a Unix
+ terminal into a particular color mode. For use in non-Unix-terminal
+ output, the string will match one of the color names exported by this
+ module.
- The text will not end with a newline character, nor will it include a
- RESET color element.
+ The second element is a string containing text to show to the user.
- To show non-colorized output, simply write the second element of each
- tuple, then a newline at the end.
+ The text will not end with a newline character, nor will it include a
+ RESET color element.
- To show colorized output, write both the first and second element of each
- tuple, then a newline at the end. Before exiting to the operating system,
- a RESET sequence should be emitted.
- """
- # TODO(bolms): Figure out how to get Vim, Emacs, etc. to parse Emboss error
- # messages.
- severity_colors = {
- ERROR: (BRIGHT_RED, BOLD),
- WARNING: (BRIGHT_MAGENTA, BOLD),
- NOTE: (BRIGHT_BLACK, WHITE)
- }
+ To show non-colorized output, simply write the second element of each
+ tuple, then a newline at the end.
- result = []
- if self.location.is_synthetic:
- pos = "[compiler bug]"
- else:
- pos = parser_types.format_position(self.location.start)
- source_name = self.source_file or "[prelude]"
- if not self.location.is_synthetic and self.source_file in source_code:
- source_lines = source_code[self.source_file].splitlines()
- source_line = source_lines[self.location.start.line - 1]
- else:
- source_line = ""
- lines = self.message.splitlines()
- for i in range(len(lines)):
- line = lines[i]
- # This is a little awkward, but we want to suppress the final newline in
- # the message. This newline is final if and only if it is the last line
- # of the message and there is no source snippet.
- if i != len(lines) - 1 or source_line:
- line += "\n"
- result.append((BOLD, "{}:{}: ".format(source_name, pos)))
- if i == 0:
- severity = self.severity
- else:
- severity = NOTE
- result.append((severity_colors[severity][0], "{}: ".format(severity)))
- result.append((severity_colors[severity][1], line))
- if source_line:
- result.append((WHITE, source_line + "\n"))
- indicator_indent = " " * (self.location.start.column - 1)
- if self.location.start.line == self.location.end.line:
- indicator_caret = "^" * max(
- 1, self.location.end.column - self.location.start.column)
- else:
- indicator_caret = "^"
- result.append((BRIGHT_GREEN, indicator_indent + indicator_caret))
- return result
+ To show colorized output, write both the first and second element of each
+ tuple, then a newline at the end. Before exiting to the operating system,
+ a RESET sequence should be emitted.
+ """
+ # TODO(bolms): Figure out how to get Vim, Emacs, etc. to parse Emboss error
+ # messages.
+ severity_colors = {
+ ERROR: (BRIGHT_RED, BOLD),
+ WARNING: (BRIGHT_MAGENTA, BOLD),
+ NOTE: (BRIGHT_BLACK, WHITE),
+ }
- def __repr__(self):
- return ("Message({source_file!r}, make_location(({start_line!r}, "
+ result = []
+ if self.location.is_synthetic:
+ pos = "[compiler bug]"
+ else:
+ pos = parser_types.format_position(self.location.start)
+ source_name = self.source_file or "[prelude]"
+ if not self.location.is_synthetic and self.source_file in source_code:
+ source_lines = source_code[self.source_file].splitlines()
+ source_line = source_lines[self.location.start.line - 1]
+ else:
+ source_line = ""
+ lines = self.message.splitlines()
+ for i in range(len(lines)):
+ line = lines[i]
+ # This is a little awkward, but we want to suppress the final newline in
+ # the message. This newline is final if and only if it is the last line
+ # of the message and there is no source snippet.
+ if i != len(lines) - 1 or source_line:
+ line += "\n"
+ result.append((BOLD, "{}:{}: ".format(source_name, pos)))
+ if i == 0:
+ severity = self.severity
+ else:
+ severity = NOTE
+ result.append((severity_colors[severity][0], "{}: ".format(severity)))
+ result.append((severity_colors[severity][1], line))
+ if source_line:
+ result.append((WHITE, source_line + "\n"))
+ indicator_indent = " " * (self.location.start.column - 1)
+ if self.location.start.line == self.location.end.line:
+ indicator_caret = "^" * max(
+ 1, self.location.end.column - self.location.start.column
+ )
+ else:
+ indicator_caret = "^"
+ result.append((BRIGHT_GREEN, indicator_indent + indicator_caret))
+ return result
+
+ def __repr__(self):
+ return (
+ "Message({source_file!r}, make_location(({start_line!r}, "
"{start_column!r}), ({end_line!r}, {end_column!r}), "
- "{is_synthetic!r}), {severity!r}, {message!r})").format(
- source_file=self.source_file,
- start_line=self.location.start.line,
- start_column=self.location.start.column,
- end_line=self.location.end.line,
- end_column=self.location.end.column,
- is_synthetic=self.location.is_synthetic,
- severity=self.severity,
- message=self.message)
+ "{is_synthetic!r}), {severity!r}, {message!r})"
+ ).format(
+ source_file=self.source_file,
+ start_line=self.location.start.line,
+ start_column=self.location.start.column,
+ end_line=self.location.end.line,
+ end_column=self.location.end.column,
+ is_synthetic=self.location.is_synthetic,
+ severity=self.severity,
+ message=self.message,
+ )
- def __eq__(self, other):
- return (
- self.__class__ == other.__class__ and self.location == other.location
- and self.source_file == other.source_file and
- self.severity == other.severity and self.message == other.message)
+ def __eq__(self, other):
+ return (
+ self.__class__ == other.__class__
+ and self.location == other.location
+ and self.source_file == other.source_file
+ and self.severity == other.severity
+ and self.message == other.message
+ )
- def __ne__(self, other):
- return not self == other
+ def __ne__(self, other):
+ return not self == other
def split_errors(errors):
- """Splits errors into (user_errors, synthetic_errors).
+ """Splits errors into (user_errors, synthetic_errors).
- Arguments:
- errors: A list of lists of _Message, which is a list of bundles of
- associated messages.
+ Arguments:
+ errors: A list of lists of _Message, which is a list of bundles of
+ associated messages.
- Returns:
- (user_errors, synthetic_errors), where both user_errors and
- synthetic_errors are lists of lists of _Message. synthetic_errors will
- contain all bundles that reference any synthetic source_location, and
- user_errors will contain the rest.
+ Returns:
+ (user_errors, synthetic_errors), where both user_errors and
+ synthetic_errors are lists of lists of _Message. synthetic_errors will
+ contain all bundles that reference any synthetic source_location, and
+ user_errors will contain the rest.
- The intent is that user_errors can be shown to end users, while
- synthetic_errors should generally be suppressed.
- """
- synthetic_errors = []
- user_errors = []
- for error_block in errors:
- if any(message.location.is_synthetic for message in error_block):
- synthetic_errors.append(error_block)
- else:
- user_errors.append(error_block)
- return user_errors, synthetic_errors
+ The intent is that user_errors can be shown to end users, while
+ synthetic_errors should generally be suppressed.
+ """
+ synthetic_errors = []
+ user_errors = []
+ for error_block in errors:
+ if any(message.location.is_synthetic for message in error_block):
+ synthetic_errors.append(error_block)
+ else:
+ user_errors.append(error_block)
+ return user_errors, synthetic_errors
def filter_errors(errors):
- """Returns the non-synthetic errors from `errors`."""
- return split_errors(errors)[0]
+ """Returns the non-synthetic errors from `errors`."""
+ return split_errors(errors)[0]
def format_errors(errors, source_codes, use_color=False):
- """Formats error messages with source code snippets."""
- result = []
- for error_group in errors:
- assert error_group, "Found empty error_group!"
- for message in error_group:
- if use_color:
- result.append("".join(e[0] + e[1] + RESET
- for e in message.format(source_codes)))
- else:
- result.append("".join(e[1] for e in message.format(source_codes)))
- return "\n".join(result)
+ """Formats error messages with source code snippets."""
+ result = []
+ for error_group in errors:
+ assert error_group, "Found empty error_group!"
+ for message in error_group:
+ if use_color:
+ result.append(
+ "".join(e[0] + e[1] + RESET for e in message.format(source_codes))
+ )
+ else:
+ result.append("".join(e[1] for e in message.format(source_codes)))
+ return "\n".join(result)
def make_error_from_parse_error(file_name, parse_error):
- return [error(file_name,
- parse_error.token.source_location,
- "{code}\n"
- "Found {text!r} ({symbol}), expected {expected}.".format(
- code=parse_error.code or "Syntax error",
- text=parse_error.token.text,
- symbol=parse_error.token.symbol,
- expected=", ".join(parse_error.expected_tokens)))]
-
-
+ return [
+ error(
+ file_name,
+ parse_error.token.source_location,
+ "{code}\n"
+ "Found {text!r} ({symbol}), expected {expected}.".format(
+ code=parse_error.code or "Syntax error",
+ text=parse_error.token.text,
+ symbol=parse_error.token.symbol,
+ expected=", ".join(parse_error.expected_tokens),
+ ),
+ )
+ ]
diff --git a/compiler/util/error_test.py b/compiler/util/error_test.py
index 7d2577f..23beddd 100644
--- a/compiler/util/error_test.py
+++ b/compiler/util/error_test.py
@@ -21,325 +21,463 @@
class MessageTest(unittest.TestCase):
- """Tests for _Message, as returned by error, warn, and note."""
+ """Tests for _Message, as returned by error, warn, and note."""
- def test_error(self):
- error_message = error.error("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "Bad thing")
- self.assertEqual("foo.emb", error_message.source_file)
- self.assertEqual(error.ERROR, error_message.severity)
- self.assertEqual(parser_types.make_location((3, 4), (3, 6)),
- error_message.location)
- self.assertEqual("Bad thing", error_message.message)
- sourceless_format = error_message.format({})
- sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"})
- self.assertEqual("foo.emb:3:4: error: Bad thing",
- "".join([x[1] for x in sourceless_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing"), # Message
- ], sourceless_format)
- self.assertEqual("foo.emb:3:4: error: Bad thing\n"
- "abcdefghijklm\n"
- " ^^", "".join([x[1] for x in sourced_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing\n"), # Message
- (error.WHITE, "abcdefghijklm\n"), # Source snippet
- (error.BRIGHT_GREEN, " ^^"), # Error column indicator
- ], sourced_format)
+ def test_error(self):
+ error_message = error.error(
+ "foo.emb", parser_types.make_location((3, 4), (3, 6)), "Bad thing"
+ )
+ self.assertEqual("foo.emb", error_message.source_file)
+ self.assertEqual(error.ERROR, error_message.severity)
+ self.assertEqual(
+ parser_types.make_location((3, 4), (3, 6)), error_message.location
+ )
+ self.assertEqual("Bad thing", error_message.message)
+ sourceless_format = error_message.format({})
+ sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"})
+ self.assertEqual(
+ "foo.emb:3:4: error: Bad thing", "".join([x[1] for x in sourceless_format])
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing"), # Message
+ ],
+ sourceless_format,
+ )
+ self.assertEqual(
+ "foo.emb:3:4: error: Bad thing\n" "abcdefghijklm\n" " ^^",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing\n"), # Message
+ (error.WHITE, "abcdefghijklm\n"), # Source snippet
+ (error.BRIGHT_GREEN, " ^^"), # Error column indicator
+ ],
+ sourced_format,
+ )
- def test_synthetic_error(self):
- error_message = error.error("foo.emb", parser_types.make_location(
- (3, 4), (3, 6), True), "Bad thing")
- sourceless_format = error_message.format({})
- sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"})
- self.assertEqual("foo.emb:[compiler bug]: error: Bad thing",
- "".join([x[1] for x in sourceless_format]))
- self.assertEqual([
- (error.BOLD, "foo.emb:[compiler bug]: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing"), # Message
- ], sourceless_format)
- self.assertEqual("foo.emb:[compiler bug]: error: Bad thing",
- "".join([x[1] for x in sourced_format]))
- self.assertEqual([
- (error.BOLD, "foo.emb:[compiler bug]: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing"), # Message
- ], sourced_format)
+ def test_synthetic_error(self):
+ error_message = error.error(
+ "foo.emb", parser_types.make_location((3, 4), (3, 6), True), "Bad thing"
+ )
+ sourceless_format = error_message.format({})
+ sourced_format = error_message.format({"foo.emb": "\n\nabcdefghijklm"})
+ self.assertEqual(
+ "foo.emb:[compiler bug]: error: Bad thing",
+ "".join([x[1] for x in sourceless_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:[compiler bug]: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing"), # Message
+ ],
+ sourceless_format,
+ )
+ self.assertEqual(
+ "foo.emb:[compiler bug]: error: Bad thing",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:[compiler bug]: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing"), # Message
+ ],
+ sourced_format,
+ )
- def test_prelude_as_file_name(self):
- error_message = error.error("", parser_types.make_location(
- (3, 4), (3, 6)), "Bad thing")
- self.assertEqual("", error_message.source_file)
- self.assertEqual(error.ERROR, error_message.severity)
- self.assertEqual(parser_types.make_location((3, 4), (3, 6)),
- error_message.location)
- self.assertEqual("Bad thing", error_message.message)
- sourceless_format = error_message.format({})
- sourced_format = error_message.format({"": "\n\nabcdefghijklm"})
- self.assertEqual("[prelude]:3:4: error: Bad thing",
- "".join([x[1] for x in sourceless_format]))
- self.assertEqual([(error.BOLD, "[prelude]:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing"), # Message
- ], sourceless_format)
- self.assertEqual("[prelude]:3:4: error: Bad thing\n"
- "abcdefghijklm\n"
- " ^^", "".join([x[1] for x in sourced_format]))
- self.assertEqual([(error.BOLD, "[prelude]:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing\n"), # Message
- (error.WHITE, "abcdefghijklm\n"), # Source snippet
- (error.BRIGHT_GREEN, " ^^"), # Error column indicator
- ], sourced_format)
+ def test_prelude_as_file_name(self):
+ error_message = error.error(
+ "", parser_types.make_location((3, 4), (3, 6)), "Bad thing"
+ )
+ self.assertEqual("", error_message.source_file)
+ self.assertEqual(error.ERROR, error_message.severity)
+ self.assertEqual(
+ parser_types.make_location((3, 4), (3, 6)), error_message.location
+ )
+ self.assertEqual("Bad thing", error_message.message)
+ sourceless_format = error_message.format({})
+ sourced_format = error_message.format({"": "\n\nabcdefghijklm"})
+ self.assertEqual(
+ "[prelude]:3:4: error: Bad thing",
+ "".join([x[1] for x in sourceless_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "[prelude]:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing"), # Message
+ ],
+ sourceless_format,
+ )
+ self.assertEqual(
+ "[prelude]:3:4: error: Bad thing\n" "abcdefghijklm\n" " ^^",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "[prelude]:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing\n"), # Message
+ (error.WHITE, "abcdefghijklm\n"), # Source snippet
+ (error.BRIGHT_GREEN, " ^^"), # Error column indicator
+ ],
+ sourced_format,
+ )
- def test_multiline_error_source(self):
- error_message = error.error("foo.emb", parser_types.make_location(
- (3, 4), (4, 6)), "Bad thing")
- self.assertEqual("foo.emb", error_message.source_file)
- self.assertEqual(error.ERROR, error_message.severity)
- self.assertEqual(parser_types.make_location((3, 4), (4, 6)),
- error_message.location)
- self.assertEqual("Bad thing", error_message.message)
- sourceless_format = error_message.format({})
- sourced_format = error_message.format(
- {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"})
- self.assertEqual("foo.emb:3:4: error: Bad thing",
- "".join([x[1] for x in sourceless_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing"), # Message
- ], sourceless_format)
- self.assertEqual("foo.emb:3:4: error: Bad thing\n"
- "abcdefghijklm\n"
- " ^", "".join([x[1] for x in sourced_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing\n"), # Message
- (error.WHITE, "abcdefghijklm\n"), # Source snippet
- (error.BRIGHT_GREEN, " ^"), # Error column indicator
- ], sourced_format)
+ def test_multiline_error_source(self):
+ error_message = error.error(
+ "foo.emb", parser_types.make_location((3, 4), (4, 6)), "Bad thing"
+ )
+ self.assertEqual("foo.emb", error_message.source_file)
+ self.assertEqual(error.ERROR, error_message.severity)
+ self.assertEqual(
+ parser_types.make_location((3, 4), (4, 6)), error_message.location
+ )
+ self.assertEqual("Bad thing", error_message.message)
+ sourceless_format = error_message.format({})
+ sourced_format = error_message.format(
+ {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"}
+ )
+ self.assertEqual(
+ "foo.emb:3:4: error: Bad thing", "".join([x[1] for x in sourceless_format])
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing"), # Message
+ ],
+ sourceless_format,
+ )
+ self.assertEqual(
+ "foo.emb:3:4: error: Bad thing\n" "abcdefghijklm\n" " ^",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing\n"), # Message
+ (error.WHITE, "abcdefghijklm\n"), # Source snippet
+ (error.BRIGHT_GREEN, " ^"), # Error column indicator
+ ],
+ sourced_format,
+ )
- def test_multiline_error(self):
- error_message = error.error("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "Bad thing\nSome explanation\nMore explanation")
- self.assertEqual("foo.emb", error_message.source_file)
- self.assertEqual(error.ERROR, error_message.severity)
- self.assertEqual(parser_types.make_location((3, 4), (3, 6)),
- error_message.location)
- self.assertEqual("Bad thing\nSome explanation\nMore explanation",
- error_message.message)
- sourceless_format = error_message.format({})
- sourced_format = error_message.format(
- {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"})
- self.assertEqual("foo.emb:3:4: error: Bad thing\n"
- "foo.emb:3:4: note: Some explanation\n"
- "foo.emb:3:4: note: More explanation",
- "".join([x[1] for x in sourceless_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing\n"), # Message
- (error.BOLD, "foo.emb:3:4: "), # Location, line 2
- (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2
- (error.WHITE, "Some explanation\n"), # Message, line 2
- (error.BOLD, "foo.emb:3:4: "), # Location, line 3
- (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3
- (error.WHITE, "More explanation"), # Message, line 3
- ], sourceless_format)
- self.assertEqual("foo.emb:3:4: error: Bad thing\n"
- "foo.emb:3:4: note: Some explanation\n"
- "foo.emb:3:4: note: More explanation\n"
- "abcdefghijklm\n"
- " ^^", "".join([x[1] for x in sourced_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_RED, "error: "), # Severity
- (error.BOLD, "Bad thing\n"), # Message
- (error.BOLD, "foo.emb:3:4: "), # Location, line 2
- (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2
- (error.WHITE, "Some explanation\n"), # Message, line 2
- (error.BOLD, "foo.emb:3:4: "), # Location, line 3
- (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3
- (error.WHITE, "More explanation\n"), # Message, line 3
- (error.WHITE, "abcdefghijklm\n"), # Source snippet
- (error.BRIGHT_GREEN, " ^^"), # Column indicator
- ], sourced_format)
+ def test_multiline_error(self):
+ error_message = error.error(
+ "foo.emb",
+ parser_types.make_location((3, 4), (3, 6)),
+ "Bad thing\nSome explanation\nMore explanation",
+ )
+ self.assertEqual("foo.emb", error_message.source_file)
+ self.assertEqual(error.ERROR, error_message.severity)
+ self.assertEqual(
+ parser_types.make_location((3, 4), (3, 6)), error_message.location
+ )
+ self.assertEqual(
+ "Bad thing\nSome explanation\nMore explanation", error_message.message
+ )
+ sourceless_format = error_message.format({})
+ sourced_format = error_message.format(
+ {"foo.emb": "\n\nabcdefghijklm\nnopqrstuv"}
+ )
+ self.assertEqual(
+ "foo.emb:3:4: error: Bad thing\n"
+ "foo.emb:3:4: note: Some explanation\n"
+ "foo.emb:3:4: note: More explanation",
+ "".join([x[1] for x in sourceless_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing\n"), # Message
+ (error.BOLD, "foo.emb:3:4: "), # Location, line 2
+ (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2
+ (error.WHITE, "Some explanation\n"), # Message, line 2
+ (error.BOLD, "foo.emb:3:4: "), # Location, line 3
+ (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3
+ (error.WHITE, "More explanation"), # Message, line 3
+ ],
+ sourceless_format,
+ )
+ self.assertEqual(
+ "foo.emb:3:4: error: Bad thing\n"
+ "foo.emb:3:4: note: Some explanation\n"
+ "foo.emb:3:4: note: More explanation\n"
+ "abcdefghijklm\n"
+ " ^^",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_RED, "error: "), # Severity
+ (error.BOLD, "Bad thing\n"), # Message
+ (error.BOLD, "foo.emb:3:4: "), # Location, line 2
+ (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 2
+ (error.WHITE, "Some explanation\n"), # Message, line 2
+ (error.BOLD, "foo.emb:3:4: "), # Location, line 3
+ (error.BRIGHT_BLACK, "note: "), # "Note" severity, line 3
+ (error.WHITE, "More explanation\n"), # Message, line 3
+ (error.WHITE, "abcdefghijklm\n"), # Source snippet
+ (error.BRIGHT_GREEN, " ^^"), # Column indicator
+ ],
+ sourced_format,
+ )
- def test_warn(self):
- warning_message = error.warn("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "Not good thing")
- self.assertEqual("foo.emb", warning_message.source_file)
- self.assertEqual(error.WARNING, warning_message.severity)
- self.assertEqual(parser_types.make_location((3, 4), (3, 6)),
- warning_message.location)
- self.assertEqual("Not good thing", warning_message.message)
- sourced_format = warning_message.format({"foo.emb": "\n\nabcdefghijklm"})
- self.assertEqual("foo.emb:3:4: warning: Not good thing\n"
- "abcdefghijklm\n"
- " ^^", "".join([x[1] for x in sourced_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_MAGENTA, "warning: "), # Severity
- (error.BOLD, "Not good thing\n"), # Message
- (error.WHITE, "abcdefghijklm\n"), # Source snippet
- (error.BRIGHT_GREEN, " ^^"), # Column indicator
- ], sourced_format)
+ def test_warn(self):
+ warning_message = error.warn(
+ "foo.emb", parser_types.make_location((3, 4), (3, 6)), "Not good thing"
+ )
+ self.assertEqual("foo.emb", warning_message.source_file)
+ self.assertEqual(error.WARNING, warning_message.severity)
+ self.assertEqual(
+ parser_types.make_location((3, 4), (3, 6)), warning_message.location
+ )
+ self.assertEqual("Not good thing", warning_message.message)
+ sourced_format = warning_message.format({"foo.emb": "\n\nabcdefghijklm"})
+ self.assertEqual(
+ "foo.emb:3:4: warning: Not good thing\n" "abcdefghijklm\n" " ^^",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_MAGENTA, "warning: "), # Severity
+ (error.BOLD, "Not good thing\n"), # Message
+ (error.WHITE, "abcdefghijklm\n"), # Source snippet
+ (error.BRIGHT_GREEN, " ^^"), # Column indicator
+ ],
+ sourced_format,
+ )
- def test_note(self):
- note_message = error.note("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "OK thing")
- self.assertEqual("foo.emb", note_message.source_file)
- self.assertEqual(error.NOTE, note_message.severity)
- self.assertEqual(parser_types.make_location((3, 4), (3, 6)),
- note_message.location)
- self.assertEqual("OK thing", note_message.message)
- sourced_format = note_message.format({"foo.emb": "\n\nabcdefghijklm"})
- self.assertEqual("foo.emb:3:4: note: OK thing\n"
- "abcdefghijklm\n"
- " ^^", "".join([x[1] for x in sourced_format]))
- self.assertEqual([(error.BOLD, "foo.emb:3:4: "), # Location
- (error.BRIGHT_BLACK, "note: "), # Severity
- (error.WHITE, "OK thing\n"), # Message
- (error.WHITE, "abcdefghijklm\n"), # Source snippet
- (error.BRIGHT_GREEN, " ^^"), # Column indicator
- ], sourced_format)
+ def test_note(self):
+ note_message = error.note(
+ "foo.emb", parser_types.make_location((3, 4), (3, 6)), "OK thing"
+ )
+ self.assertEqual("foo.emb", note_message.source_file)
+ self.assertEqual(error.NOTE, note_message.severity)
+ self.assertEqual(
+ parser_types.make_location((3, 4), (3, 6)), note_message.location
+ )
+ self.assertEqual("OK thing", note_message.message)
+ sourced_format = note_message.format({"foo.emb": "\n\nabcdefghijklm"})
+ self.assertEqual(
+ "foo.emb:3:4: note: OK thing\n" "abcdefghijklm\n" " ^^",
+ "".join([x[1] for x in sourced_format]),
+ )
+ self.assertEqual(
+ [
+ (error.BOLD, "foo.emb:3:4: "), # Location
+ (error.BRIGHT_BLACK, "note: "), # Severity
+ (error.WHITE, "OK thing\n"), # Message
+ (error.WHITE, "abcdefghijklm\n"), # Source snippet
+ (error.BRIGHT_GREEN, " ^^"), # Column indicator
+ ],
+ sourced_format,
+ )
- def test_equality(self):
- note_message = error.note("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "thing")
- self.assertEqual(note_message,
- error.note("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "thing"))
- self.assertNotEqual(note_message,
- error.warn("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "thing"))
- self.assertNotEqual(note_message,
- error.note("foo2.emb", parser_types.make_location(
- (3, 4), (3, 6)), "thing"))
- self.assertNotEqual(note_message,
- error.note("foo.emb", parser_types.make_location(
- (2, 4), (3, 6)), "thing"))
- self.assertNotEqual(note_message,
- error.note("foo.emb", parser_types.make_location(
- (3, 4), (3, 6)), "thing2"))
+ def test_equality(self):
+ note_message = error.note(
+ "foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing"
+ )
+ self.assertEqual(
+ note_message,
+ error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing"),
+ )
+ self.assertNotEqual(
+ note_message,
+ error.warn("foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing"),
+ )
+ self.assertNotEqual(
+ note_message,
+ error.note("foo2.emb", parser_types.make_location((3, 4), (3, 6)), "thing"),
+ )
+ self.assertNotEqual(
+ note_message,
+ error.note("foo.emb", parser_types.make_location((2, 4), (3, 6)), "thing"),
+ )
+ self.assertNotEqual(
+ note_message,
+ error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), "thing2"),
+ )
class StringTest(unittest.TestCase):
- """Tests for strings."""
+ """Tests for strings."""
- # These strings are a fixed part of the API.
+ # These strings are a fixed part of the API.
- def test_color_strings(self):
- self.assertEqual("\033[0;30m", error.BLACK)
- self.assertEqual("\033[0;31m", error.RED)
- self.assertEqual("\033[0;32m", error.GREEN)
- self.assertEqual("\033[0;33m", error.YELLOW)
- self.assertEqual("\033[0;34m", error.BLUE)
- self.assertEqual("\033[0;35m", error.MAGENTA)
- self.assertEqual("\033[0;36m", error.CYAN)
- self.assertEqual("\033[0;37m", error.WHITE)
- self.assertEqual("\033[0;1;30m", error.BRIGHT_BLACK)
- self.assertEqual("\033[0;1;31m", error.BRIGHT_RED)
- self.assertEqual("\033[0;1;32m", error.BRIGHT_GREEN)
- self.assertEqual("\033[0;1;33m", error.BRIGHT_YELLOW)
- self.assertEqual("\033[0;1;34m", error.BRIGHT_BLUE)
- self.assertEqual("\033[0;1;35m", error.BRIGHT_MAGENTA)
- self.assertEqual("\033[0;1;36m", error.BRIGHT_CYAN)
- self.assertEqual("\033[0;1;37m", error.BRIGHT_WHITE)
- self.assertEqual("\033[0;1m", error.BOLD)
- self.assertEqual("\033[0m", error.RESET)
+ def test_color_strings(self):
+ self.assertEqual("\033[0;30m", error.BLACK)
+ self.assertEqual("\033[0;31m", error.RED)
+ self.assertEqual("\033[0;32m", error.GREEN)
+ self.assertEqual("\033[0;33m", error.YELLOW)
+ self.assertEqual("\033[0;34m", error.BLUE)
+ self.assertEqual("\033[0;35m", error.MAGENTA)
+ self.assertEqual("\033[0;36m", error.CYAN)
+ self.assertEqual("\033[0;37m", error.WHITE)
+ self.assertEqual("\033[0;1;30m", error.BRIGHT_BLACK)
+ self.assertEqual("\033[0;1;31m", error.BRIGHT_RED)
+ self.assertEqual("\033[0;1;32m", error.BRIGHT_GREEN)
+ self.assertEqual("\033[0;1;33m", error.BRIGHT_YELLOW)
+ self.assertEqual("\033[0;1;34m", error.BRIGHT_BLUE)
+ self.assertEqual("\033[0;1;35m", error.BRIGHT_MAGENTA)
+ self.assertEqual("\033[0;1;36m", error.BRIGHT_CYAN)
+ self.assertEqual("\033[0;1;37m", error.BRIGHT_WHITE)
+ self.assertEqual("\033[0;1m", error.BOLD)
+ self.assertEqual("\033[0m", error.RESET)
- def test_error_strings(self):
- self.assertEqual("error", error.ERROR)
- self.assertEqual("warning", error.WARNING)
- self.assertEqual("note", error.NOTE)
+ def test_error_strings(self):
+ self.assertEqual("error", error.ERROR)
+ self.assertEqual("warning", error.WARNING)
+ self.assertEqual("note", error.NOTE)
class SplitErrorsTest(unittest.TestCase):
- def test_split_errors(self):
- user_error = [
- error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((3, 4), (5, 6)),
- "Note: bad thing referrent")
- ]
- user_error_2 = [
- error.error("foo.emb", parser_types.make_location((8, 9), (10, 11)),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((10, 11), (12, 13)),
- "Note: bad thing referrent")
- ]
- synthetic_error = [
- error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((3, 4), (5, 6), True),
- "Note: bad thing referrent")
- ]
- synthetic_error_2 = [
- error.error("foo.emb",
- parser_types.make_location((8, 9), (10, 11), True),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((10, 11), (12, 13)),
- "Note: bad thing referrent")
- ]
- user_errors, synthetic_errors = error.split_errors(
- [user_error, synthetic_error])
- self.assertEqual([user_error], user_errors)
- self.assertEqual([synthetic_error], synthetic_errors)
- user_errors, synthetic_errors = error.split_errors(
- [synthetic_error, user_error])
- self.assertEqual([user_error], user_errors)
- self.assertEqual([synthetic_error], synthetic_errors)
- user_errors, synthetic_errors = error.split_errors(
- [synthetic_error, user_error, synthetic_error_2, user_error_2])
- self.assertEqual([user_error, user_error_2], user_errors)
- self.assertEqual([synthetic_error, synthetic_error_2], synthetic_errors)
+ def test_split_errors(self):
+ user_error = [
+ error.error(
+ "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing"
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((3, 4), (5, 6)),
+ "Note: bad thing referrent",
+ ),
+ ]
+ user_error_2 = [
+ error.error(
+ "foo.emb", parser_types.make_location((8, 9), (10, 11)), "Bad thing"
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((10, 11), (12, 13)),
+ "Note: bad thing referrent",
+ ),
+ ]
+ synthetic_error = [
+ error.error(
+ "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing"
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((3, 4), (5, 6), True),
+ "Note: bad thing referrent",
+ ),
+ ]
+ synthetic_error_2 = [
+ error.error(
+ "foo.emb",
+ parser_types.make_location((8, 9), (10, 11), True),
+ "Bad thing",
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((10, 11), (12, 13)),
+ "Note: bad thing referrent",
+ ),
+ ]
+ user_errors, synthetic_errors = error.split_errors(
+ [user_error, synthetic_error]
+ )
+ self.assertEqual([user_error], user_errors)
+ self.assertEqual([synthetic_error], synthetic_errors)
+ user_errors, synthetic_errors = error.split_errors(
+ [synthetic_error, user_error]
+ )
+ self.assertEqual([user_error], user_errors)
+ self.assertEqual([synthetic_error], synthetic_errors)
+ user_errors, synthetic_errors = error.split_errors(
+ [synthetic_error, user_error, synthetic_error_2, user_error_2]
+ )
+ self.assertEqual([user_error, user_error_2], user_errors)
+ self.assertEqual([synthetic_error, synthetic_error_2], synthetic_errors)
- def test_filter_errors(self):
- user_error = [
- error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((3, 4), (5, 6)),
- "Note: bad thing referrent")
- ]
- synthetic_error = [
- error.error("foo.emb", parser_types.make_location((1, 2), (3, 4)),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((3, 4), (5, 6), True),
- "Note: bad thing referrent")
- ]
- synthetic_error_2 = [
- error.error("foo.emb",
- parser_types.make_location((8, 9), (10, 11), True),
- "Bad thing"),
- error.note("foo.emb", parser_types.make_location((10, 11), (12, 13)),
- "Note: bad thing referrent")
- ]
- self.assertEqual(
- [user_error],
- error.filter_errors([synthetic_error, user_error, synthetic_error_2]))
+ def test_filter_errors(self):
+ user_error = [
+ error.error(
+ "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing"
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((3, 4), (5, 6)),
+ "Note: bad thing referrent",
+ ),
+ ]
+ synthetic_error = [
+ error.error(
+ "foo.emb", parser_types.make_location((1, 2), (3, 4)), "Bad thing"
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((3, 4), (5, 6), True),
+ "Note: bad thing referrent",
+ ),
+ ]
+ synthetic_error_2 = [
+ error.error(
+ "foo.emb",
+ parser_types.make_location((8, 9), (10, 11), True),
+ "Bad thing",
+ ),
+ error.note(
+ "foo.emb",
+ parser_types.make_location((10, 11), (12, 13)),
+ "Note: bad thing referrent",
+ ),
+ ]
+ self.assertEqual(
+ [user_error],
+ error.filter_errors([synthetic_error, user_error, synthetic_error_2]),
+ )
class FormatErrorsTest(unittest.TestCase):
- def test_format_errors(self):
- errors = [[error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)),
- "note")]]
- sources = {"foo.emb": "x\ny\nz bcd\nq\n"}
- self.assertEqual("foo.emb:3:4: note: note\n"
- "z bcd\n"
- " ^^", error.format_errors(errors, sources))
- bold = error.BOLD
- reset = error.RESET
- white = error.WHITE
- bright_black = error.BRIGHT_BLACK
- bright_green = error.BRIGHT_GREEN
- self.assertEqual(bold + "foo.emb:3:4: " + reset + bright_black + "note: " +
- reset + white + "note\n" +
- reset + white + "z bcd\n" +
- reset + bright_green + " ^^" + reset,
- error.format_errors(errors, sources, use_color=True))
+ def test_format_errors(self):
+ errors = [
+ [error.note("foo.emb", parser_types.make_location((3, 4), (3, 6)), "note")]
+ ]
+ sources = {"foo.emb": "x\ny\nz bcd\nq\n"}
+ self.assertEqual(
+ "foo.emb:3:4: note: note\n" "z bcd\n" " ^^",
+ error.format_errors(errors, sources),
+ )
+ bold = error.BOLD
+ reset = error.RESET
+ white = error.WHITE
+ bright_black = error.BRIGHT_BLACK
+ bright_green = error.BRIGHT_GREEN
+ self.assertEqual(
+ bold
+ + "foo.emb:3:4: "
+ + reset
+ + bright_black
+ + "note: "
+ + reset
+ + white
+ + "note\n"
+ + reset
+ + white
+ + "z bcd\n"
+ + reset
+ + bright_green
+ + " ^^"
+ + reset,
+ error.format_errors(errors, sources, use_color=True),
+ )
+
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/expression_parser.py b/compiler/util/expression_parser.py
index 708f23b..3981548 100644
--- a/compiler/util/expression_parser.py
+++ b/compiler/util/expression_parser.py
@@ -20,28 +20,28 @@
def parse(text):
- """Parses text as an Expression.
+ """Parses text as an Expression.
- This parses text using the expression subset of the Emboss grammar, and
- returns an ir_data.Expression. The expression only needs to be syntactically
- valid; it will not go through symbol resolution or type checking. This
- function is not intended to be called on arbitrary input; it asserts that the
- text successfully parses, but does not return errors.
+ This parses text using the expression subset of the Emboss grammar, and
+ returns an ir_data.Expression. The expression only needs to be syntactically
+ valid; it will not go through symbol resolution or type checking. This
+ function is not intended to be called on arbitrary input; it asserts that the
+ text successfully parses, but does not return errors.
- Arguments:
- text: The text of an Emboss expression, like "4 + 5" or "$max(1, a, b)".
+ Arguments:
+ text: The text of an Emboss expression, like "4 + 5" or "$max(1, a, b)".
- Returns:
- An ir_data.Expression corresponding to the textual form.
+ Returns:
+ An ir_data.Expression corresponding to the textual form.
- Raises:
- AssertionError if text is not a well-formed Emboss expression, and
- assertions are enabled.
- """
- tokens, errors = tokenizer.tokenize(text, "")
- assert not errors, "{!r}".format(errors)
- # tokenizer.tokenize always inserts a newline token at the end, which breaks
- # expression parsing.
- parse_result = parser.parse_expression(tokens[:-1])
- assert not parse_result.error, "{!r}".format(parse_result.error)
- return module_ir.build_ir(parse_result.parse_tree)
+ Raises:
+ AssertionError if text is not a well-formed Emboss expression, and
+ assertions are enabled.
+ """
+ tokens, errors = tokenizer.tokenize(text, "")
+ assert not errors, "{!r}".format(errors)
+ # tokenizer.tokenize always inserts a newline token at the end, which breaks
+ # expression parsing.
+ parse_result = parser.parse_expression(tokens[:-1])
+ assert not parse_result.error, "{!r}".format(parse_result.error)
+ return module_ir.build_ir(parse_result.parse_tree)
diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py
index 624b09a..af8c2f7 100644
--- a/compiler/util/ir_data.py
+++ b/compiler/util/ir_data.py
@@ -27,79 +27,95 @@
@dataclasses.dataclass
class Message:
- """Base class for IR data objects.
+ """Base class for IR data objects.
- Historically protocol buffers were used for serializing this data which has
- led to some legacy naming conventions and references. In particular this
- class is named `Message` in the sense of a protocol buffer message,
- indicating that it is intended to just be data that is used by other higher
- level services.
+ Historically protocol buffers were used for serializing this data which has
+ led to some legacy naming conventions and references. In particular this
+ class is named `Message` in the sense of a protocol buffer message,
+ indicating that it is intended to just be data that is used by other higher
+ level services.
- There are some other legacy idioms leftover from the protocol buffer-based
- definition such as support for "oneof" and optional fields.
- """
-
- IR_DATACLASS: ClassVar[object] = object()
- field_specs: ClassVar[ir_data_fields.FilteredIrFieldSpecs]
-
- def __post_init__(self):
- """Called by dataclass subclasses after init.
-
- Post-processes any lists passed in to use our custom list type.
+ There are some other legacy idioms leftover from the protocol buffer-based
+ definition such as support for "oneof" and optional fields.
"""
- # Convert any lists passed in to CopyValuesList
- for spec in self.field_specs.sequence_field_specs:
- cur_val = getattr(self, spec.name)
- if isinstance(cur_val, ir_data_fields.TemporaryCopyValuesList):
- copy_val = cur_val.temp_list
- else:
- copy_val = ir_data_fields.CopyValuesList(spec.data_type)
- if cur_val:
- copy_val.shallow_copy(cur_val)
- setattr(self, spec.name, copy_val)
- # This hook adds a 15% overhead to end-to-end code generation in some cases
- # so we guard it in a `__debug__` block. Users can opt-out of this check by
- # running python with the `-O` flag, ie: `python3 -O ./embossc`.
- if __debug__:
- def __setattr__(self, name: str, value) -> None:
- """Debug-only hook that adds basic type checking for ir_data fields."""
- if spec := self.field_specs.all_field_specs.get(name):
- if not (
- # Check if it's the expected type
- isinstance(value, spec.data_type) or
- # Oneof fields are a special case
- spec.is_oneof or
- # Optional fields can be set to None
- (spec.container is ir_data_fields.FieldContainer.OPTIONAL and
- value is None) or
- # Sequences can be a few variants of lists
- (spec.is_sequence and
- isinstance(value, (
- list, ir_data_fields.TemporaryCopyValuesList,
- ir_data_fields.CopyValuesList))) or
- # An enum value can be an int
- (spec.is_enum and isinstance(value, int))):
- raise AttributeError(
- f"Cannot set {value} (type {value.__class__}) for type"
- "{spec.data_type}")
- object.__setattr__(self, name, value)
+ IR_DATACLASS: ClassVar[object] = object()
+ field_specs: ClassVar[ir_data_fields.FilteredIrFieldSpecs]
- # Non-PEP8 name to mimic the Google Protobuf interface.
- def HasField(self, name): # pylint:disable=invalid-name
- """Indicates if this class has the given field defined and it is set."""
- return getattr(self, name, None) is not None
+ def __post_init__(self):
+ """Called by dataclass subclasses after init.
- # Non-PEP8 name to mimic the Google Protobuf interface.
- def WhichOneof(self, oneof_name): # pylint:disable=invalid-name
- """Indicates which field has been set for the oneof value.
+ Post-processes any lists passed in to use our custom list type.
+ """
+ # Convert any lists passed in to CopyValuesList
+ for spec in self.field_specs.sequence_field_specs:
+ cur_val = getattr(self, spec.name)
+ if isinstance(cur_val, ir_data_fields.TemporaryCopyValuesList):
+ copy_val = cur_val.temp_list
+ else:
+ copy_val = ir_data_fields.CopyValuesList(spec.data_type)
+ if cur_val:
+ copy_val.shallow_copy(cur_val)
+ setattr(self, spec.name, copy_val)
- Returns None if no field has been set.
- """
- for field_name, oneof in self.field_specs.oneof_mappings:
- if oneof == oneof_name and self.HasField(field_name):
- return field_name
- return None
+ # This hook adds a 15% overhead to end-to-end code generation in some cases
+ # so we guard it in a `__debug__` block. Users can opt-out of this check by
+ # running python with the `-O` flag, ie: `python3 -O ./embossc`.
+ if __debug__:
+
+ def __setattr__(self, name: str, value) -> None:
+ """Debug-only hook that adds basic type checking for ir_data fields."""
+ if spec := self.field_specs.all_field_specs.get(name):
+ if not (
+ # Check if it's the expected type
+ isinstance(value, spec.data_type)
+ or
+ # Oneof fields are a special case
+ spec.is_oneof
+ or
+ # Optional fields can be set to None
+ (
+ spec.container is ir_data_fields.FieldContainer.OPTIONAL
+ and value is None
+ )
+ or
+ # Sequences can be a few variants of lists
+ (
+ spec.is_sequence
+ and isinstance(
+ value,
+ (
+ list,
+ ir_data_fields.TemporaryCopyValuesList,
+ ir_data_fields.CopyValuesList,
+ ),
+ )
+ )
+ or
+ # An enum value can be an int
+ (spec.is_enum and isinstance(value, int))
+ ):
+ raise AttributeError(
+ f"Cannot set {value} (type {value.__class__}) for type"
+ "{spec.data_type}"
+ )
+ object.__setattr__(self, name, value)
+
+ # Non-PEP8 name to mimic the Google Protobuf interface.
+ def HasField(self, name): # pylint:disable=invalid-name
+ """Indicates if this class has the given field defined and it is set."""
+ return getattr(self, name, None) is not None
+
+ # Non-PEP8 name to mimic the Google Protobuf interface.
+ def WhichOneof(self, oneof_name): # pylint:disable=invalid-name
+ """Indicates which field has been set for the oneof value.
+
+ Returns None if no field has been set.
+ """
+ for field_name, oneof in self.field_specs.oneof_mappings:
+ if oneof == oneof_name and self.HasField(field_name):
+ return field_name
+ return None
################################################################################
@@ -108,28 +124,28 @@
@dataclasses.dataclass
class Position(Message):
- """A zero-width position within a source file."""
+ """A zero-width position within a source file."""
- line: int = 0
- """Line (starts from 1)."""
- column: int = 0
- """Column (starts from 1)."""
+ line: int = 0
+ """Line (starts from 1)."""
+ column: int = 0
+ """Column (starts from 1)."""
@dataclasses.dataclass
class Location(Message):
- """A half-open start:end range within a source file."""
+ """A half-open start:end range within a source file."""
- start: Optional[Position] = None
- """Beginning of the range"""
- end: Optional[Position] = None
- """One column past the end of the range."""
+ start: Optional[Position] = None
+ """Beginning of the range"""
+ end: Optional[Position] = None
+ """One column past the end of the range."""
- is_disjoint_from_parent: Optional[bool] = None
- """True if this Location is outside of the parent object's Location."""
+ is_disjoint_from_parent: Optional[bool] = None
+ """True if this Location is outside of the parent object's Location."""
- is_synthetic: Optional[bool] = None
- """True if this Location's parent was synthesized, and does not directly
+ is_synthetic: Optional[bool] = None
+ """True if this Location's parent was synthesized, and does not directly
appear in the source file.
The Emboss front end uses this field to cull
@@ -139,124 +155,124 @@
@dataclasses.dataclass
class Word(Message):
- """IR for a bare word in the source file.
+ """IR for a bare word in the source file.
- This is used in NameDefinitions and References.
- """
+ This is used in NameDefinitions and References.
+ """
- text: Optional[str] = None
- source_location: Optional[Location] = None
+ text: Optional[str] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class String(Message):
- """IR for a string in the source file."""
+ """IR for a string in the source file."""
- text: Optional[str] = None
- source_location: Optional[Location] = None
+ text: Optional[str] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Documentation(Message):
- text: Optional[str] = None
- source_location: Optional[Location] = None
+ text: Optional[str] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class BooleanConstant(Message):
- """IR for a boolean constant."""
+ """IR for a boolean constant."""
- value: Optional[bool] = None
- source_location: Optional[Location] = None
+ value: Optional[bool] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Empty(Message):
- """Placeholder message for automatic element counts for arrays."""
+ """Placeholder message for automatic element counts for arrays."""
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class NumericConstant(Message):
- """IR for any numeric constant."""
+ """IR for any numeric constant."""
- # Numeric constants are stored as decimal strings; this is the simplest way
- # to store the full -2**63..+2**64 range.
- #
- # TODO(bolms): switch back to int, and just use strings during
- # serialization, now that we're free of proto.
- value: Optional[str] = None
- source_location: Optional[Location] = None
+ # Numeric constants are stored as decimal strings; this is the simplest way
+ # to store the full -2**63..+2**64 range.
+ #
+ # TODO(bolms): switch back to int, and just use strings during
+ # serialization, now that we're free of proto.
+ value: Optional[str] = None
+ source_location: Optional[Location] = None
class FunctionMapping(int, enum.Enum):
- """Enum of supported function types"""
+ """Enum of supported function types"""
- UNKNOWN = 0
- ADDITION = 1
- """`+`"""
- SUBTRACTION = 2
- """`-`"""
- MULTIPLICATION = 3
- """`*`"""
- EQUALITY = 4
- """`==`"""
- INEQUALITY = 5
- """`!=`"""
- AND = 6
- """`&&`"""
- OR = 7
- """`||`"""
- LESS = 8
- """`<`"""
- LESS_OR_EQUAL = 9
- """`<=`"""
- GREATER = 10
- """`>`"""
- GREATER_OR_EQUAL = 11
- """`>=`"""
- CHOICE = 12
- """`?:`"""
- MAXIMUM = 13
- """`$max()`"""
- PRESENCE = 14
- """`$present()`"""
- UPPER_BOUND = 15
- """`$upper_bound()`"""
- LOWER_BOUND = 16
- """`$lower_bound()`"""
+ UNKNOWN = 0
+ ADDITION = 1
+ """`+`"""
+ SUBTRACTION = 2
+ """`-`"""
+ MULTIPLICATION = 3
+ """`*`"""
+ EQUALITY = 4
+ """`==`"""
+ INEQUALITY = 5
+ """`!=`"""
+ AND = 6
+ """`&&`"""
+ OR = 7
+ """`||`"""
+ LESS = 8
+ """`<`"""
+ LESS_OR_EQUAL = 9
+ """`<=`"""
+ GREATER = 10
+ """`>`"""
+ GREATER_OR_EQUAL = 11
+ """`>=`"""
+ CHOICE = 12
+ """`?:`"""
+ MAXIMUM = 13
+ """`$max()`"""
+ PRESENCE = 14
+ """`$present()`"""
+ UPPER_BOUND = 15
+ """`$upper_bound()`"""
+ LOWER_BOUND = 16
+ """`$lower_bound()`"""
@dataclasses.dataclass
class Function(Message):
- """IR for a single function (+, -, *, ==, $max, etc.) in an expression."""
+ """IR for a single function (+, -, *, ==, $max, etc.) in an expression."""
- function: Optional[FunctionMapping] = None
- args: list["Expression"] = ir_data_fields.list_field(lambda: Expression)
- function_name: Optional[Word] = None
- source_location: Optional[Location] = None
+ function: Optional[FunctionMapping] = None
+ args: list["Expression"] = ir_data_fields.list_field(lambda: Expression)
+ function_name: Optional[Word] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class CanonicalName(Message):
- """CanonicalName is the unique, absolute name for some object.
+ """CanonicalName is the unique, absolute name for some object.
- A CanonicalName is the unique, absolute name for some object (Type, field,
- etc.) in the IR. It is used both in the definitions of objects ("struct
- Foo"), and in references to objects (a field of type "Foo").
- """
+ A CanonicalName is the unique, absolute name for some object (Type, field,
+ etc.) in the IR. It is used both in the definitions of objects ("struct
+ Foo"), and in references to objects (a field of type "Foo").
+ """
- module_file: str = ir_data_fields.str_field()
- """The module_file is the Module.source_file_name of the Module in which this
+ module_file: str = ir_data_fields.str_field()
+ """The module_file is the Module.source_file_name of the Module in which this
object's definition appears.
Note that the Prelude always has a Module.source_file_name of "", and thus
references to Prelude names will have module_file == "".
"""
- object_path: list[str] = ir_data_fields.list_field(str)
- """The object_path is the canonical path to the object definition within its
+ object_path: list[str] = ir_data_fields.list_field(str)
+ """The object_path is the canonical path to the object definition within its
module file.
For example, the field "bar" would have an object path of
@@ -279,160 +295,160 @@
@dataclasses.dataclass
class NameDefinition(Message):
- """NameDefinition is IR for the name of an object, within the object.
+ """NameDefinition is IR for the name of an object, within the object.
- That is, a TypeDefinition or Field will hold a NameDefinition as its
- name.
- """
+ That is, a TypeDefinition or Field will hold a NameDefinition as its
+ name.
+ """
- name: Optional[Word] = None
- """The name, as directly generated from the source text.
+ name: Optional[Word] = None
+ """The name, as directly generated from the source text.
name.text will match the last element of canonical_name.object_path. Note
that in some cases, the exact string in name.text may not appear in the
source text.
"""
- canonical_name: Optional[CanonicalName] = None
- """The CanonicalName that will appear in References.
+ canonical_name: Optional[CanonicalName] = None
+ """The CanonicalName that will appear in References.
This field is technically redundant: canonical_name.module_file should always
match the source_file_name of the enclosing Module, and
canonical_name.object_path should always match the names of parent nodes.
"""
- is_anonymous: Optional[bool] = None
- """If true, indicates that this is an automatically-generated name, which
+ is_anonymous: Optional[bool] = None
+ """If true, indicates that this is an automatically-generated name, which
should not be visible outside of its immediate namespace.
"""
- source_location: Optional[Location] = None
- """The location of this NameDefinition in source code."""
+ source_location: Optional[Location] = None
+ """The location of this NameDefinition in source code."""
@dataclasses.dataclass
class Reference(Message):
- """A Reference holds the canonical name of something defined elsewhere.
+ """A Reference holds the canonical name of something defined elsewhere.
- For example, take this fragment:
+ For example, take this fragment:
- struct Foo:
- 0:3 UInt size (s)
- 4:s Int:8[] payload
+ struct Foo:
+ 0:3 UInt size (s)
+ 4:s Int:8[] payload
- "Foo", "size", and "payload" will become NameDefinitions in their
- corresponding Field and Message IR objects, while "UInt", the second "s",
- and "Int" are References. Note that the second "s" will have a
- canonical_name.object_path of ["Foo", "size"], not ["Foo", "s"]: the
- Reference always holds the single "true" name of the object, regardless of
- what appears in the .emb.
- """
+ "Foo", "size", and "payload" will become NameDefinitions in their
+ corresponding Field and Message IR objects, while "UInt", the second "s",
+ and "Int" are References. Note that the second "s" will have a
+ canonical_name.object_path of ["Foo", "size"], not ["Foo", "s"]: the
+ Reference always holds the single "true" name of the object, regardless of
+ what appears in the .emb.
+ """
- canonical_name: Optional[CanonicalName] = None
- """The canonical name of the object being referred to.
+ canonical_name: Optional[CanonicalName] = None
+ """The canonical name of the object being referred to.
This name should be used to find the object in the IR.
"""
- source_name: list[Word] = ir_data_fields.list_field(Word)
- """The source_name is the name the user entered in the source file.
+ source_name: list[Word] = ir_data_fields.list_field(Word)
+ """The source_name is the name the user entered in the source file.
The source_name could be either relative or absolute, and may be an alias
(and thus not match any part of the canonical_name). Back ends should use
canonical_name for name lookup, and reserve source_name for error messages.
"""
- is_local_name: Optional[bool] = None
- """If true, then symbol resolution should only look at local names when
+ is_local_name: Optional[bool] = None
+ """If true, then symbol resolution should only look at local names when
resolving source_name.
This is used so that the names of inline types aren't "ambiguous" if there
happens to be another type with the same name at a parent scope.
"""
- # TODO(bolms): Allow absolute paths starting with ".".
+ # TODO(bolms): Allow absolute paths starting with ".".
- source_location: Optional[Location] = None
- """Note that this is the source_location of the *Reference*, not of the
+ source_location: Optional[Location] = None
+ """Note that this is the source_location of the *Reference*, not of the
object to which it refers.
"""
@dataclasses.dataclass
class FieldReference(Message):
- """IR for a "field" or "field.sub.subsub" reference in an expression.
+ """IR for a "field" or "field.sub.subsub" reference in an expression.
- The first element of "path" is the "base" field, which should be directly
- readable in the (runtime) context of the expression. For example:
+ The first element of "path" is the "base" field, which should be directly
+ readable in the (runtime) context of the expression. For example:
- struct Foo:
- 0:1 UInt header_size (h)
- 0:h UInt:8[] header_bytes
+ struct Foo:
+ 0:1 UInt header_size (h)
+ 0:h UInt:8[] header_bytes
- The "h" will translate to ["Foo", "header_size"], which will be the first
- (and in this case only) element of "path".
+ The "h" will translate to ["Foo", "header_size"], which will be the first
+ (and in this case only) element of "path".
- Subsequent path elements should be treated as subfields. For example, in:
+ Subsequent path elements should be treated as subfields. For example, in:
- struct Foo:
- struct Sizes:
- 0:1 UInt header_size
- 1:2 UInt body_size
- 0 [+2] Sizes sizes
- 0 [+sizes.header_size] UInt:8[] header
- sizes.header_size [+sizes.body_size] UInt:8[] body
+ struct Foo:
+ struct Sizes:
+ 0:1 UInt header_size
+ 1:2 UInt body_size
+ 0 [+2] Sizes sizes
+ 0 [+sizes.header_size] UInt:8[] header
+ sizes.header_size [+sizes.body_size] UInt:8[] body
- The references to "sizes.header_size" will have a path of [["Foo",
- "sizes"], ["Foo", "Sizes", "header_size"]]. Note that each path element is
- a fully-qualified reference; some back ends (C++, Python) may only use the
- last element, while others (C) may use the complete path.
+ The references to "sizes.header_size" will have a path of [["Foo",
+ "sizes"], ["Foo", "Sizes", "header_size"]]. Note that each path element is
+ a fully-qualified reference; some back ends (C++, Python) may only use the
+ last element, while others (C) may use the complete path.
- This representation is a bit awkward, and is fundamentally limited to a
- dotted list of static field names. It does not allow an expression like
- `array[n]` on the left side of a `.`. At this point, it is an artifact of
- the era during which I (bolms@) thought I could get away with skipping
- compiler-y things.
- """
+ This representation is a bit awkward, and is fundamentally limited to a
+ dotted list of static field names. It does not allow an expression like
+ `array[n]` on the left side of a `.`. At this point, it is an artifact of
+ the era during which I (bolms@) thought I could get away with skipping
+ compiler-y things.
+ """
- # TODO(bolms): Add composite types to the expression type system, and
- # replace FieldReference with a "member access" Expression kind. Further,
- # move the symbol resolution for FieldReferences that is currently in
- # symbol_resolver.py into type_check.py.
+ # TODO(bolms): Add composite types to the expression type system, and
+ # replace FieldReference with a "member access" Expression kind. Further,
+ # move the symbol resolution for FieldReferences that is currently in
+ # symbol_resolver.py into type_check.py.
- # TODO(bolms): Make the above change before declaring the IR to be "stable".
+ # TODO(bolms): Make the above change before declaring the IR to be "stable".
- path: list[Reference] = ir_data_fields.list_field(Reference)
- source_location: Optional[Location] = None
+ path: list[Reference] = ir_data_fields.list_field(Reference)
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class OpaqueType(Message):
- pass
+ pass
@dataclasses.dataclass
class IntegerType(Message):
- """Type of an integer expression."""
+ """Type of an integer expression."""
- # For optimization, the modular congruence of an integer expression is
- # tracked. This consists of a modulus and a modular_value, such that for
- # all possible values of expression, expression MOD modulus ==
- # modular_value.
- #
- # The modulus may be the special value "infinity" to indicate that the
- # expression's value is exactly modular_value; otherwise, it should be a
- # positive integer.
- #
- # A modulus of 1 places no constraints on the value.
- #
- # The modular_value should always be a nonnegative integer that is smaller
- # than the modulus.
- #
- # Note that this is specifically the *modulus*, which is not equivalent to
- # the value from C's '%' operator when the dividend is negative: in C, -7 %
- # 4 == -3, but the modular_value here would be 1. Python uses modulus: in
- # Python, -7 % 4 == 1.
- modulus: Optional[str] = None
- """The modulus portion of the modular congruence of an integer expression.
+ # For optimization, the modular congruence of an integer expression is
+ # tracked. This consists of a modulus and a modular_value, such that for
+ # all possible values of expression, expression MOD modulus ==
+ # modular_value.
+ #
+ # The modulus may be the special value "infinity" to indicate that the
+ # expression's value is exactly modular_value; otherwise, it should be a
+ # positive integer.
+ #
+ # A modulus of 1 places no constraints on the value.
+ #
+ # The modular_value should always be a nonnegative integer that is smaller
+ # than the modulus.
+ #
+ # Note that this is specifically the *modulus*, which is not equivalent to
+ # the value from C's '%' operator when the dividend is negative: in C, -7 %
+ # 4 == -3, but the modular_value here would be 1. Python uses modulus: in
+ # Python, -7 % 4 == 1.
+ modulus: Optional[str] = None
+ """The modulus portion of the modular congruence of an integer expression.
The modulus may be the special value "infinity" to indicate that the
expression's value is exactly modular_value; otherwise, it should be a
@@ -440,171 +456,165 @@
A modulus of 1 places no constraints on the value.
"""
- modular_value: Optional[str] = None
- """ The modular_value portion of the modular congruence of an integer expression.
+ modular_value: Optional[str] = None
+ """ The modular_value portion of the modular congruence of an integer expression.
The modular_value should always be a nonnegative integer that is smaller
than the modulus.
"""
- # The minimum and maximum values of an integer are tracked and checked so
- # that Emboss can implement reliable arithmetic with no operations
- # overflowing either 64-bit unsigned or 64-bit signed 2's-complement
- # integers.
- #
- # Note that constant subexpressions are allowed to overflow, as long as the
- # final, computed constant value of the subexpression fits in a 64-bit
- # value.
- #
- # The minimum_value may take the value "-infinity", and the maximum_value
- # may take the value "infinity". These sentinel values indicate that
- # Emboss has no bound information for the Expression, and therefore the
- # Expression may only be evaluated during compilation; the back end should
- # never need to compile such an expression into the target language (e.g.,
- # C++).
- minimum_value: Optional[str] = None
- maximum_value: Optional[str] = None
+ # The minimum and maximum values of an integer are tracked and checked so
+ # that Emboss can implement reliable arithmetic with no operations
+ # overflowing either 64-bit unsigned or 64-bit signed 2's-complement
+ # integers.
+ #
+ # Note that constant subexpressions are allowed to overflow, as long as the
+ # final, computed constant value of the subexpression fits in a 64-bit
+ # value.
+ #
+ # The minimum_value may take the value "-infinity", and the maximum_value
+ # may take the value "infinity". These sentinel values indicate that
+ # Emboss has no bound information for the Expression, and therefore the
+ # Expression may only be evaluated during compilation; the back end should
+ # never need to compile such an expression into the target language (e.g.,
+ # C++).
+ minimum_value: Optional[str] = None
+ maximum_value: Optional[str] = None
@dataclasses.dataclass
class BooleanType(Message):
- value: Optional[bool] = None
+ value: Optional[bool] = None
@dataclasses.dataclass
class EnumType(Message):
- name: Optional[Reference] = None
- value: Optional[str] = None
+ name: Optional[Reference] = None
+ value: Optional[str] = None
@dataclasses.dataclass
class ExpressionType(Message):
- opaque: Optional[OpaqueType] = ir_data_fields.oneof_field("type")
- integer: Optional[IntegerType] = ir_data_fields.oneof_field("type")
- boolean: Optional[BooleanType] = ir_data_fields.oneof_field("type")
- enumeration: Optional[EnumType] = ir_data_fields.oneof_field("type")
+ opaque: Optional[OpaqueType] = ir_data_fields.oneof_field("type")
+ integer: Optional[IntegerType] = ir_data_fields.oneof_field("type")
+ boolean: Optional[BooleanType] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[EnumType] = ir_data_fields.oneof_field("type")
@dataclasses.dataclass
class Expression(Message):
- """IR for an expression.
+ """IR for an expression.
- An Expression is a potentially-recursive data structure. It can either
- represent a leaf node (constant or reference) or an operation combining
- other Expressions (function).
- """
+ An Expression is a potentially-recursive data structure. It can either
+ represent a leaf node (constant or reference) or an operation combining
+ other Expressions (function).
+ """
- constant: Optional[NumericConstant] = ir_data_fields.oneof_field("expression")
- constant_reference: Optional[Reference] = ir_data_fields.oneof_field(
- "expression"
- )
- function: Optional[Function] = ir_data_fields.oneof_field("expression")
- field_reference: Optional[FieldReference] = ir_data_fields.oneof_field(
- "expression"
- )
- boolean_constant: Optional[BooleanConstant] = ir_data_fields.oneof_field(
- "expression"
- )
- builtin_reference: Optional[Reference] = ir_data_fields.oneof_field(
- "expression"
- )
+ constant: Optional[NumericConstant] = ir_data_fields.oneof_field("expression")
+ constant_reference: Optional[Reference] = ir_data_fields.oneof_field("expression")
+ function: Optional[Function] = ir_data_fields.oneof_field("expression")
+ field_reference: Optional[FieldReference] = ir_data_fields.oneof_field("expression")
+ boolean_constant: Optional[BooleanConstant] = ir_data_fields.oneof_field(
+ "expression"
+ )
+ builtin_reference: Optional[Reference] = ir_data_fields.oneof_field("expression")
- type: Optional[ExpressionType] = None
- source_location: Optional[Location] = None
+ type: Optional[ExpressionType] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class ArrayType(Message):
- """IR for an array type ("Int:8[12]" or "Message[2]" or "UInt[3][2]")."""
+ """IR for an array type ("Int:8[12]" or "Message[2]" or "UInt[3][2]")."""
- base_type: Optional["Type"] = None
+ base_type: Optional["Type"] = None
- element_count: Optional[Expression] = ir_data_fields.oneof_field("size")
- automatic: Optional[Empty] = ir_data_fields.oneof_field("size")
+ element_count: Optional[Expression] = ir_data_fields.oneof_field("size")
+ automatic: Optional[Empty] = ir_data_fields.oneof_field("size")
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class AtomicType(Message):
- """IR for a non-array type ("UInt" or "Foo(Version.SIX)")."""
+ """IR for a non-array type ("UInt" or "Foo(Version.SIX)")."""
- reference: Optional[Reference] = None
- runtime_parameter: list[Expression] = ir_data_fields.list_field(Expression)
- source_location: Optional[Location] = None
+ reference: Optional[Reference] = None
+ runtime_parameter: list[Expression] = ir_data_fields.list_field(Expression)
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Type(Message):
- """IR for a type reference ("UInt", "Int:8[12]", etc.)."""
+ """IR for a type reference ("UInt", "Int:8[12]", etc.)."""
- atomic_type: Optional[AtomicType] = ir_data_fields.oneof_field("type")
- array_type: Optional[ArrayType] = ir_data_fields.oneof_field("type")
+ atomic_type: Optional[AtomicType] = ir_data_fields.oneof_field("type")
+ array_type: Optional[ArrayType] = ir_data_fields.oneof_field("type")
- size_in_bits: Optional[Expression] = None
- source_location: Optional[Location] = None
+ size_in_bits: Optional[Expression] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class AttributeValue(Message):
- """IR for a attribute value."""
+ """IR for a attribute value."""
- # TODO(bolms): Make String a type of Expression, and replace
- # AttributeValue with Expression.
- expression: Optional[Expression] = ir_data_fields.oneof_field("value")
- string_constant: Optional[String] = ir_data_fields.oneof_field("value")
+ # TODO(bolms): Make String a type of Expression, and replace
+ # AttributeValue with Expression.
+ expression: Optional[Expression] = ir_data_fields.oneof_field("value")
+ string_constant: Optional[String] = ir_data_fields.oneof_field("value")
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Attribute(Message):
- """IR for a [name = value] attribute."""
+ """IR for a [name = value] attribute."""
- name: Optional[Word] = None
- value: Optional[AttributeValue] = None
- back_end: Optional[Word] = None
- is_default: Optional[bool] = None
- source_location: Optional[Location] = None
+ name: Optional[Word] = None
+ value: Optional[AttributeValue] = None
+ back_end: Optional[Word] = None
+ is_default: Optional[bool] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class WriteTransform(Message):
- """IR which defines an expression-based virtual field write scheme.
+ """IR which defines an expression-based virtual field write scheme.
- E.g., for a virtual field like `x_plus_one`:
+ E.g., for a virtual field like `x_plus_one`:
- struct Foo:
- 0 [+1] UInt x
- let x_plus_one = x + 1
+ struct Foo:
+ 0 [+1] UInt x
+ let x_plus_one = x + 1
- ... the `WriteMethod` would be `transform`, with `$logical_value - 1` for
- `function_body` and `x` for `destination`.
- """
+ ... the `WriteMethod` would be `transform`, with `$logical_value - 1` for
+ `function_body` and `x` for `destination`.
+ """
- function_body: Optional[Expression] = None
- destination: Optional[FieldReference] = None
+ function_body: Optional[Expression] = None
+ destination: Optional[FieldReference] = None
@dataclasses.dataclass
class WriteMethod(Message):
- """IR which defines the method used for writing to a virtual field."""
+ """IR which defines the method used for writing to a virtual field."""
- physical: Optional[bool] = ir_data_fields.oneof_field("method")
- """A physical Field can be written directly."""
+ physical: Optional[bool] = ir_data_fields.oneof_field("method")
+ """A physical Field can be written directly."""
- read_only: Optional[bool] = ir_data_fields.oneof_field("method")
- """A read_only Field cannot be written."""
+ read_only: Optional[bool] = ir_data_fields.oneof_field("method")
+ """A read_only Field cannot be written."""
- alias: Optional[FieldReference] = ir_data_fields.oneof_field("method")
- """An alias is a direct, untransformed forward of another field; it can be
+ alias: Optional[FieldReference] = ir_data_fields.oneof_field("method")
+ """An alias is a direct, untransformed forward of another field; it can be
implemented by directly returning a reference to the aliased field.
Aliases are the only kind of virtual field that may have an opaque type.
"""
- transform: Optional[WriteTransform] = ir_data_fields.oneof_field("method")
- """A transform is a way of turning a logical value into a value which should
+ transform: Optional[WriteTransform] = ir_data_fields.oneof_field("method")
+ """A transform is a way of turning a logical value into a value which should
be written to another field.
A virtual field like `let y = x + 1` would
@@ -615,47 +625,47 @@
@dataclasses.dataclass
class FieldLocation(Message):
- """IR for a field location."""
+ """IR for a field location."""
- start: Optional[Expression] = None
- size: Optional[Expression] = None
- source_location: Optional[Location] = None
+ start: Optional[Expression] = None
+ size: Optional[Expression] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Field(Message): # pylint:disable=too-many-instance-attributes
- """IR for a field in a struct definition.
+ """IR for a field in a struct definition.
- There are two kinds of Field: physical fields have location and (physical)
- type; virtual fields have read_transform. Although there are differences,
- in many situations physical and virtual fields are treated the same way,
- and they can be freely intermingled in the source file.
- """
+ There are two kinds of Field: physical fields have location and (physical)
+ type; virtual fields have read_transform. Although there are differences,
+ in many situations physical and virtual fields are treated the same way,
+ and they can be freely intermingled in the source file.
+ """
- location: Optional[FieldLocation] = None
- """The physical location of the field."""
- type: Optional[Type] = None
- """The physical type of the field."""
+ location: Optional[FieldLocation] = None
+ """The physical location of the field."""
+ type: Optional[Type] = None
+ """The physical type of the field."""
- read_transform: Optional[Expression] = None
- """The value of a virtual field."""
+ read_transform: Optional[Expression] = None
+ """The value of a virtual field."""
- write_method: Optional[WriteMethod] = None
- """How this virtual field should be written."""
+ write_method: Optional[WriteMethod] = None
+ """How this virtual field should be written."""
- name: Optional[NameDefinition] = None
- """The name of the field."""
- abbreviation: Optional[Word] = None
- """An optional short name for the field, only visible inside the enclosing bits/struct."""
- attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
- """Field-specific attributes."""
- documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
- """Field-specific documentation."""
+ name: Optional[NameDefinition] = None
+ """The name of the field."""
+ abbreviation: Optional[Word] = None
+ """An optional short name for the field, only visible inside the enclosing bits/struct."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """Field-specific attributes."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Field-specific documentation."""
- # TODO(bolms): Document conditional fields better, and replace some of this
- # explanation with a reference to the documentation.
- existence_condition: Optional[Expression] = None
- """The field only exists when existence_condition evaluates to true.
+ # TODO(bolms): Document conditional fields better, and replace some of this
+ # explanation with a reference to the documentation.
+ existence_condition: Optional[Expression] = None
+ """The field only exists when existence_condition evaluates to true.
For example:
```
@@ -690,17 +700,17 @@
`bar`: those fields only conditionally exist in the structure.
"""
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Structure(Message):
- """IR for a bits or struct definition."""
+ """IR for a bits or struct definition."""
- field: list[Field] = ir_data_fields.list_field(Field)
+ field: list[Field] = ir_data_fields.list_field(Field)
- fields_in_dependency_order: list[int] = ir_data_fields.list_field(int)
- """The fields in `field` are listed in the order they appear in the original
+ fields_in_dependency_order: list[int] = ir_data_fields.list_field(int)
+ """The fields in `field` are listed in the order they appear in the original
.emb.
For text format output, this can lead to poor results. Take the following
@@ -734,66 +744,66 @@
be `{ 0, 1, 2, 3, ... }`.
"""
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class External(Message):
- """IR for an external type declaration."""
+ """IR for an external type declaration."""
- # Externals have no values other than name and attribute list, which are
- # common to all type definitions.
+ # Externals have no values other than name and attribute list, which are
+ # common to all type definitions.
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class EnumValue(Message):
- """IR for a single value within an enumerated type."""
+ """IR for a single value within an enumerated type."""
- name: Optional[NameDefinition] = None
- """The name of the enum value."""
- value: Optional[Expression] = None
- """The value of the enum value."""
- documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
- """Value-specific documentation."""
- attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
- """Value-specific attributes."""
+ name: Optional[NameDefinition] = None
+ """The name of the enum value."""
+ value: Optional[Expression] = None
+ """The value of the enum value."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Value-specific documentation."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """Value-specific attributes."""
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Enum(Message):
- """IR for an enumerated type definition."""
+ """IR for an enumerated type definition."""
- value: list[EnumValue] = ir_data_fields.list_field(EnumValue)
- source_location: Optional[Location] = None
+ value: list[EnumValue] = ir_data_fields.list_field(EnumValue)
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Import(Message):
- """IR for an import statement in a module."""
+ """IR for an import statement in a module."""
- file_name: Optional[String] = None
- """The file to import."""
- local_name: Optional[Word] = None
- """The name to use within this module."""
- source_location: Optional[Location] = None
+ file_name: Optional[String] = None
+ """The file to import."""
+ local_name: Optional[Word] = None
+ """The name to use within this module."""
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class RuntimeParameter(Message):
- """IR for a runtime parameter definition."""
+ """IR for a runtime parameter definition."""
- name: Optional[NameDefinition] = None
- """The name of the parameter."""
- type: Optional[ExpressionType] = None
- """The type of the parameter."""
+ name: Optional[NameDefinition] = None
+ """The name of the parameter."""
+ type: Optional[ExpressionType] = None
+ """The type of the parameter."""
- # TODO(bolms): Actually implement the set builder type notation.
- physical_type_alias: Optional[Type] = None
- """For convenience and readability, physical types may be used in the .emb
+ # TODO(bolms): Actually implement the set builder type notation.
+ physical_type_alias: Optional[Type] = None
+ """For convenience and readability, physical types may be used in the .emb
source instead of a full expression type.
That way, users can write
@@ -809,80 +819,78 @@
is filled in after initial parsing is finished.
"""
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
class AddressableUnit(int, enum.Enum):
- """The "addressable unit" is the size of the smallest unit that can be read
+ """The "addressable unit" is the size of the smallest unit that can be read
- from the backing store that this type expects. For `struct`s, this is
- BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends
- on the specific type
- """
+ from the backing store that this type expects. For `struct`s, this is
+ BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends
+ on the specific type
+ """
- NONE = 0
- BIT = 1
- BYTE = 8
+ NONE = 0
+ BIT = 1
+ BYTE = 8
@dataclasses.dataclass
class TypeDefinition(Message):
- """Container IR for a type definition (struct, union, etc.)"""
+ """Container IR for a type definition (struct, union, etc.)"""
- external: Optional[External] = ir_data_fields.oneof_field("type")
- enumeration: Optional[Enum] = ir_data_fields.oneof_field("type")
- structure: Optional[Structure] = ir_data_fields.oneof_field("type")
+ external: Optional[External] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[Enum] = ir_data_fields.oneof_field("type")
+ structure: Optional[Structure] = ir_data_fields.oneof_field("type")
- name: Optional[NameDefinition] = None
- """The name of the type."""
- attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
- """All attributes attached to the type."""
- documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
- """Docs for the type."""
- # pylint:disable=undefined-variable
- subtype: list["TypeDefinition"] = ir_data_fields.list_field(
- lambda: TypeDefinition
- )
- """Subtypes of this type."""
- addressable_unit: Optional[AddressableUnit] = None
+ name: Optional[NameDefinition] = None
+ """The name of the type."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """All attributes attached to the type."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Docs for the type."""
+ # pylint:disable=undefined-variable
+ subtype: list["TypeDefinition"] = ir_data_fields.list_field(lambda: TypeDefinition)
+ """Subtypes of this type."""
+ addressable_unit: Optional[AddressableUnit] = None
- runtime_parameter: list[RuntimeParameter] = ir_data_fields.list_field(
- RuntimeParameter
- )
- """If the type requires parameters at runtime, these are its parameters.
+ runtime_parameter: list[RuntimeParameter] = ir_data_fields.list_field(
+ RuntimeParameter
+ )
+ """If the type requires parameters at runtime, these are its parameters.
These are currently only allowed on structures, but in the future they
should be allowed on externals.
"""
- source_location: Optional[Location] = None
+ source_location: Optional[Location] = None
@dataclasses.dataclass
class Module(Message):
- """The IR for an individual Emboss module (file)."""
+ """The IR for an individual Emboss module (file)."""
- attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
- """Module-level attributes."""
- type: list[TypeDefinition] = ir_data_fields.list_field(TypeDefinition)
- """Module-level type definitions."""
- documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
- """Module-level docs."""
- foreign_import: list[Import] = ir_data_fields.list_field(Import)
- """Other modules imported."""
- source_text: Optional[str] = None
- """The original source code."""
- source_location: Optional[Location] = None
- """Source code covered by this IR."""
- source_file_name: Optional[str] = None
- """Name of the source file."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """Module-level attributes."""
+ type: list[TypeDefinition] = ir_data_fields.list_field(TypeDefinition)
+ """Module-level type definitions."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Module-level docs."""
+ foreign_import: list[Import] = ir_data_fields.list_field(Import)
+ """Other modules imported."""
+ source_text: Optional[str] = None
+ """The original source code."""
+ source_location: Optional[Location] = None
+ """Source code covered by this IR."""
+ source_file_name: Optional[str] = None
+ """Name of the source file."""
@dataclasses.dataclass
class EmbossIr(Message):
- """The top-level IR for an Emboss module and all of its dependencies."""
+ """The top-level IR for an Emboss module and all of its dependencies."""
- module: list[Module] = ir_data_fields.list_field(Module)
- """All modules.
+ module: list[Module] = ir_data_fields.list_field(Module)
+ """All modules.
The first entry will be the main module; back ends should
generate code corresponding to that module. The second entry will be the
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index 52f3a8a..002ea3b 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -57,11 +57,11 @@
class IrDataclassInstance(Protocol):
- """Type bound for an IR dataclass instance."""
+ """Type bound for an IR dataclass instance."""
- __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
- IR_DATACLASS: ClassVar[object]
- field_specs: ClassVar["FilteredIrFieldSpecs"]
+ __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
+ IR_DATACLASS: ClassVar[object]
+ field_specs: ClassVar["FilteredIrFieldSpecs"]
IrDataT = TypeVar("IrDataT", bound=IrDataclassInstance)
@@ -72,249 +72,245 @@
def _is_ir_dataclass(obj):
- return hasattr(obj, _IR_DATACLASSS_ATTR)
+ return hasattr(obj, _IR_DATACLASSS_ATTR)
class CopyValuesList(list[CopyValuesListT]):
- """A list that makes copies of any value that is inserted"""
+ """A list that makes copies of any value that is inserted"""
- def __init__(
- self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None
- ):
- if iterable:
- super().__init__(iterable)
- else:
- super().__init__()
- self.value_type = value_type
+ def __init__(
+ self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None
+ ):
+ if iterable:
+ super().__init__(iterable)
+ else:
+ super().__init__()
+ self.value_type = value_type
- def _copy(self, obj: Any):
- if _is_ir_dataclass(obj):
- return copy(obj)
- return self.value_type(obj)
+ def _copy(self, obj: Any):
+ if _is_ir_dataclass(obj):
+ return copy(obj)
+ return self.value_type(obj)
- def extend(self, iterable: Iterable) -> None:
- return super().extend([self._copy(i) for i in iterable])
+ def extend(self, iterable: Iterable) -> None:
+ return super().extend([self._copy(i) for i in iterable])
- def shallow_copy(self, iterable: Iterable) -> None:
- """Explicitly performs a shallow copy of the provided list"""
- return super().extend(iterable)
+ def shallow_copy(self, iterable: Iterable) -> None:
+ """Explicitly performs a shallow copy of the provided list"""
+ return super().extend(iterable)
- def append(self, obj: Any) -> None:
- return super().append(self._copy(obj))
+ def append(self, obj: Any) -> None:
+ return super().append(self._copy(obj))
- def insert(self, index: SupportsIndex, obj: Any) -> None:
- return super().insert(index, self._copy(obj))
+ def insert(self, index: SupportsIndex, obj: Any) -> None:
+ return super().insert(index, self._copy(obj))
class TemporaryCopyValuesList(NamedTuple):
- """Class used to temporarily hold a CopyValuesList while copying and
- constructing an IR dataclass.
- """
+ """Class used to temporarily hold a CopyValuesList while copying and
+ constructing an IR dataclass.
+ """
- temp_list: CopyValuesList
+ temp_list: CopyValuesList
class FieldContainer(enum.Enum):
- """Indicates a fields container type"""
+ """Indicates a fields container type"""
- NONE = 0
- OPTIONAL = 1
- LIST = 2
+ NONE = 0
+ OPTIONAL = 1
+ LIST = 2
class FieldSpec(NamedTuple):
- """Indicates the container and type of a field.
+ """Indicates the container and type of a field.
- `FieldSpec` objects are accessed millions of times during runs so we cache as
- many operations as possible.
- - `is_dataclass`: `dataclasses.is_dataclass(data_type)`
- - `is_sequence`: `container is FieldContainer.LIST`
- - `is_enum`: `issubclass(data_type, enum.Enum)`
- - `is_oneof`: `oneof is not None`
+ `FieldSpec` objects are accessed millions of times during runs so we cache as
+ many operations as possible.
+ - `is_dataclass`: `dataclasses.is_dataclass(data_type)`
+ - `is_sequence`: `container is FieldContainer.LIST`
+ - `is_enum`: `issubclass(data_type, enum.Enum)`
+ - `is_oneof`: `oneof is not None`
- Use `make_field_spec` to automatically fill in the cached operations.
- """
+ Use `make_field_spec` to automatically fill in the cached operations.
+ """
- name: str
- data_type: type
- container: FieldContainer
- oneof: Optional[str]
- is_dataclass: bool
- is_sequence: bool
- is_enum: bool
- is_oneof: bool
+ name: str
+ data_type: type
+ container: FieldContainer
+ oneof: Optional[str]
+ is_dataclass: bool
+ is_sequence: bool
+ is_enum: bool
+ is_oneof: bool
def make_field_spec(
name: str, data_type: type, container: FieldContainer, oneof: Optional[str]
):
- """Builds a field spec with cached type queries."""
- return FieldSpec(
- name,
- data_type,
- container,
- oneof,
- is_dataclass=_is_ir_dataclass(data_type),
- is_sequence=container is FieldContainer.LIST,
- is_enum=issubclass(data_type, enum.Enum),
- is_oneof=oneof is not None,
- )
+ """Builds a field spec with cached type queries."""
+ return FieldSpec(
+ name,
+ data_type,
+ container,
+ oneof,
+ is_dataclass=_is_ir_dataclass(data_type),
+ is_sequence=container is FieldContainer.LIST,
+ is_enum=issubclass(data_type, enum.Enum),
+ is_oneof=oneof is not None,
+ )
def build_default(field_spec: FieldSpec):
- """Builds a default instance of the given field"""
- if field_spec.is_sequence:
- return CopyValuesList(field_spec.data_type)
- if field_spec.is_enum:
- return field_spec.data_type(int())
- return field_spec.data_type()
+ """Builds a default instance of the given field"""
+ if field_spec.is_sequence:
+ return CopyValuesList(field_spec.data_type)
+ if field_spec.is_enum:
+ return field_spec.data_type(int())
+ return field_spec.data_type()
class FilteredIrFieldSpecs:
- """Provides cached views of an IR dataclass' fields."""
+ """Provides cached views of an IR dataclass' fields."""
- def __init__(self, specs: Mapping[str, FieldSpec]):
- self.all_field_specs = specs
- self.field_specs = tuple(specs.values())
- self.dataclass_field_specs = {
- k: v for k, v in specs.items() if v.is_dataclass
- }
- self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof}
- self.sequence_field_specs = tuple(
- v for v in specs.values() if v.is_sequence
- )
- self.oneof_mappings = tuple(
- (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof
- )
+ def __init__(self, specs: Mapping[str, FieldSpec]):
+ self.all_field_specs = specs
+ self.field_specs = tuple(specs.values())
+ self.dataclass_field_specs = {k: v for k, v in specs.items() if v.is_dataclass}
+ self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof}
+ self.sequence_field_specs = tuple(v for v in specs.values() if v.is_sequence)
+ self.oneof_mappings = tuple(
+ (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof
+ )
def all_ir_classes(mod):
- """Retrieves a list of all IR dataclass definitions in the given module."""
- return (
- v
- for v in mod.__dict__.values()
- if isinstance(type, v.__class__) and _is_ir_dataclass(v)
- )
+ """Retrieves a list of all IR dataclass definitions in the given module."""
+ return (
+ v
+ for v in mod.__dict__.values()
+ if isinstance(type, v.__class__) and _is_ir_dataclass(v)
+ )
class IrDataclassSpecs:
- """Maintains a cache of all IR dataclass specs."""
+ """Maintains a cache of all IR dataclass specs."""
- spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {}
+ spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {}
- @classmethod
- def get_mod_specs(cls, mod):
- """Gets the IR dataclass specs for the given module."""
- return {
- ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
- for ir_class in all_ir_classes(mod)
- }
+ @classmethod
+ def get_mod_specs(cls, mod):
+ """Gets the IR dataclass specs for the given module."""
+ return {
+ ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
+ for ir_class in all_ir_classes(mod)
+ }
- @classmethod
- def get_specs(cls, data_class):
- """Gets the field specs for the given class. The specs will be cached."""
- if data_class not in cls.spec_cache:
- mod = sys.modules[data_class.__module__]
- cls.spec_cache.update(cls.get_mod_specs(mod))
- return cls.spec_cache[data_class]
+ @classmethod
+ def get_specs(cls, data_class):
+ """Gets the field specs for the given class. The specs will be cached."""
+ if data_class not in cls.spec_cache:
+ mod = sys.modules[data_class.__module__]
+ cls.spec_cache.update(cls.get_mod_specs(mod))
+ return cls.spec_cache[data_class]
def cache_message_specs(mod, cls):
- """Adds a cached `field_specs` attribute to IR dataclasses in `mod`
- excluding the given base `cls`.
+ """Adds a cached `field_specs` attribute to IR dataclasses in `mod`
+ excluding the given base `cls`.
- This needs to be done after the dataclass decorators run and create the
- wrapped classes.
- """
- for data_class in all_ir_classes(mod):
- if data_class is not cls:
- data_class.field_specs = IrDataclassSpecs.get_specs(data_class)
+ This needs to be done after the dataclass decorators run and create the
+ wrapped classes.
+ """
+ for data_class in all_ir_classes(mod):
+ if data_class is not cls:
+ data_class.field_specs = IrDataclassSpecs.get_specs(data_class)
def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]:
- """Gets the IR data field names and types for the given IR data class"""
- # Get the dataclass fields
- class_fields = dataclasses.fields(cast(Any, cls))
+ """Gets the IR data field names and types for the given IR data class"""
+ # Get the dataclass fields
+ class_fields = dataclasses.fields(cast(Any, cls))
- # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute
- # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`.
- # Instead we manually subsitute the type by extracting the list of classes
- # from the class' module and manually substituting.
- mod_ns = {
- k: v
- for k, v in sys.modules[cls.__module__].__dict__.items()
- if isinstance(type, v.__class__)
- }
+ # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute
+ # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`.
+ # Instead we manually subsitute the type by extracting the list of classes
+ # from the class' module and manually substituting.
+ mod_ns = {
+ k: v
+ for k, v in sys.modules[cls.__module__].__dict__.items()
+ if isinstance(type, v.__class__)
+ }
- # Now extract the concrete type out of optionals
- result: MutableMapping[str, FieldSpec] = {}
- for class_field in class_fields:
- if class_field.name.startswith("_"):
- continue
- container_type = FieldContainer.NONE
- type_hint = class_field.type
- oneof = class_field.metadata.get("oneof")
+ # Now extract the concrete type out of optionals
+ result: MutableMapping[str, FieldSpec] = {}
+ for class_field in class_fields:
+ if class_field.name.startswith("_"):
+ continue
+ container_type = FieldContainer.NONE
+ type_hint = class_field.type
+ oneof = class_field.metadata.get("oneof")
- # Check if this type is wrapped
- origin = get_origin(type_hint)
- # Get the wrapped types if there are any
- args = get_args(type_hint)
- if origin is not None:
- # Extract the type.
- type_hint = args[0]
+ # Check if this type is wrapped
+ origin = get_origin(type_hint)
+ # Get the wrapped types if there are any
+ args = get_args(type_hint)
+ if origin is not None:
+ # Extract the type.
+ type_hint = args[0]
- # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we
- # have to check if it's a `Union` instead of just using `Optional`.
- if origin == Union:
- # Make sure this is an `Optional` and not another `Union` type.
- assert len(args) == 2 and args[1] == type(None)
- container_type = FieldContainer.OPTIONAL
- elif origin == list:
- container_type = FieldContainer.LIST
- else:
- raise TypeError(f"Field has invalid container type: {origin}")
+ # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we
+ # have to check if it's a `Union` instead of just using `Optional`.
+ if origin == Union:
+ # Make sure this is an `Optional` and not another `Union` type.
+ assert len(args) == 2 and args[1] == type(None)
+ container_type = FieldContainer.OPTIONAL
+ elif origin == list:
+ container_type = FieldContainer.LIST
+ else:
+ raise TypeError(f"Field has invalid container type: {origin}")
- # Resolve any forward references.
- if isinstance(type_hint, str):
- type_hint = mod_ns[type_hint]
- if isinstance(type_hint, ForwardRef):
- type_hint = mod_ns[type_hint.__forward_arg__]
+ # Resolve any forward references.
+ if isinstance(type_hint, str):
+ type_hint = mod_ns[type_hint]
+ if isinstance(type_hint, ForwardRef):
+ type_hint = mod_ns[type_hint.__forward_arg__]
- result[class_field.name] = make_field_spec(
- class_field.name, type_hint, container_type, oneof
- )
+ result[class_field.name] = make_field_spec(
+ class_field.name, type_hint, container_type, oneof
+ )
- return result
+ return result
def field_specs(obj: Union[IrDataT, type[IrDataT]]) -> Mapping[str, FieldSpec]:
- """Retrieves the fields specs for the the give data type.
+ """Retrieves the fields specs for the the give data type.
- The results of this method are cached to reduce lookup overhead.
- """
- cls = obj if isinstance(obj, type) else type(obj)
- if cls is type(None):
- raise TypeError("field_specs called with invalid type: NoneType")
- return IrDataclassSpecs.get_specs(cls).all_field_specs
+ The results of this method are cached to reduce lookup overhead.
+ """
+ cls = obj if isinstance(obj, type) else type(obj)
+ if cls is type(None):
+ raise TypeError("field_specs called with invalid type: NoneType")
+ return IrDataclassSpecs.get_specs(cls).all_field_specs
def fields_and_values(
ir: IrDataT,
value_filt: Optional[Callable[[Any], bool]] = None,
):
- """Retrieves the fields and their values for a given IR data class.
+ """Retrieves the fields and their values for a given IR data class.
- Args:
- ir: The IR data class or a read-only wrapper of an IR data class.
- value_filt: Optional filter used to exclude values.
- """
- set_fields: list[Tuple[FieldSpec, Any]] = []
- specs: FilteredIrFieldSpecs = ir.field_specs
- for spec in specs.field_specs:
- value = getattr(ir, spec.name)
- if not value_filt or value_filt(value):
- set_fields.append((spec, value))
- return set_fields
+ Args:
+ ir: The IR data class or a read-only wrapper of an IR data class.
+ value_filt: Optional filter used to exclude values.
+ """
+ set_fields: list[Tuple[FieldSpec, Any]] = []
+ specs: FilteredIrFieldSpecs = ir.field_specs
+ for spec in specs.field_specs:
+ value = getattr(ir, spec.name)
+ if not value_filt or value_filt(value):
+ set_fields.append((spec, value))
+ return set_fields
# `copy` is one of the hottest paths of embossc. We've taken steps to
@@ -333,117 +329,119 @@
# 5. None checks are only done in `copy()`, `_copy_set_fields` only
# references `_copy()` to avoid this step.
def _copy_set_fields(ir: IrDataT):
- values: MutableMapping[str, Any] = {}
+ values: MutableMapping[str, Any] = {}
- specs: FilteredIrFieldSpecs = ir.field_specs
- for spec in specs.field_specs:
- value = getattr(ir, spec.name)
- if value is not None:
- if spec.is_sequence:
- if spec.is_dataclass:
- copy_value = CopyValuesList(spec.data_type, (_copy(v) for v in value))
- value = TemporaryCopyValuesList(copy_value)
- else:
- copy_value = CopyValuesList(spec.data_type, value)
- value = TemporaryCopyValuesList(copy_value)
- elif spec.is_dataclass:
- value = _copy(value)
- values[spec.name] = value
- return values
+ specs: FilteredIrFieldSpecs = ir.field_specs
+ for spec in specs.field_specs:
+ value = getattr(ir, spec.name)
+ if value is not None:
+ if spec.is_sequence:
+ if spec.is_dataclass:
+ copy_value = CopyValuesList(
+ spec.data_type, (_copy(v) for v in value)
+ )
+ value = TemporaryCopyValuesList(copy_value)
+ else:
+ copy_value = CopyValuesList(spec.data_type, value)
+ value = TemporaryCopyValuesList(copy_value)
+ elif spec.is_dataclass:
+ value = _copy(value)
+ values[spec.name] = value
+ return values
def _copy(ir: IrDataT) -> IrDataT:
- return type(ir)(**_copy_set_fields(ir)) # type: ignore[misc]
+ return type(ir)(**_copy_set_fields(ir)) # type: ignore[misc]
def copy(ir: IrDataT) -> Optional[IrDataT]:
- """Creates a copy of the given IR data class"""
- if not ir:
- return None
- return _copy(ir)
+ """Creates a copy of the given IR data class"""
+ if not ir:
+ return None
+ return _copy(ir)
def update(ir: IrDataT, template: IrDataT):
- """Updates `ir`s fields with all set fields in the template."""
- for k, v in _copy_set_fields(template).items():
- if isinstance(v, TemporaryCopyValuesList):
- v = v.temp_list
- setattr(ir, k, v)
+ """Updates `ir`s fields with all set fields in the template."""
+ for k, v in _copy_set_fields(template).items():
+ if isinstance(v, TemporaryCopyValuesList):
+ v = v.temp_list
+ setattr(ir, k, v)
class OneOfField:
- """Decorator for a "oneof" field.
+ """Decorator for a "oneof" field.
- Tracks when the field is set and will unset othe fields in the associated
- oneof group.
+ Tracks when the field is set and will unset othe fields in the associated
+ oneof group.
- Note: Decorators only work if dataclass slots aren't used.
- """
+ Note: Decorators only work if dataclass slots aren't used.
+ """
- def __init__(self, oneof: str) -> None:
- super().__init__()
- self.oneof = oneof
- self.owner_type = None
- self.proxy_name: str = ""
- self.name: str = ""
+ def __init__(self, oneof: str) -> None:
+ super().__init__()
+ self.oneof = oneof
+ self.owner_type = None
+ self.proxy_name: str = ""
+ self.name: str = ""
- def __set_name__(self, owner, name):
- self.name = name
- self.proxy_name = f"_{name}"
- self.owner_type = owner
- # Add our empty proxy field to the class.
- setattr(owner, self.proxy_name, None)
+ def __set_name__(self, owner, name):
+ self.name = name
+ self.proxy_name = f"_{name}"
+ self.owner_type = owner
+ # Add our empty proxy field to the class.
+ setattr(owner, self.proxy_name, None)
- def __get__(self, obj, objtype=None):
- return getattr(obj, self.proxy_name)
+ def __get__(self, obj, objtype=None):
+ return getattr(obj, self.proxy_name)
- def __set__(self, obj, value):
- if value is self:
- # This will happen if the dataclass uses the default value, we just
- # default to None.
- value = None
+ def __set__(self, obj, value):
+ if value is self:
+ # This will happen if the dataclass uses the default value, we just
+ # default to None.
+ value = None
- if value is not None:
- # Clear the others
- for name, oneof in IrDataclassSpecs.get_specs(
- self.owner_type
- ).oneof_mappings:
- if oneof == self.oneof and name != self.name:
- setattr(obj, name, None)
+ if value is not None:
+ # Clear the others
+ for name, oneof in IrDataclassSpecs.get_specs(
+ self.owner_type
+ ).oneof_mappings:
+ if oneof == self.oneof and name != self.name:
+ setattr(obj, name, None)
- setattr(obj, self.proxy_name, value)
+ setattr(obj, self.proxy_name, value)
def oneof_field(name: str):
- """Alternative for `datclasses.field` that sets up a oneof variable"""
- return dataclasses.field( # pylint:disable=invalid-field-call
- default=OneOfField(name), metadata={"oneof": name}, init=True
- )
+ """Alternative for `datclasses.field` that sets up a oneof variable"""
+ return dataclasses.field( # pylint:disable=invalid-field-call
+ default=OneOfField(name), metadata={"oneof": name}, init=True
+ )
def str_field():
- """Helper used to define a defaulted str field"""
- return dataclasses.field(default_factory=str) # pylint:disable=invalid-field-call
+ """Helper used to define a defaulted str field"""
+ return dataclasses.field(default_factory=str) # pylint:disable=invalid-field-call
def list_field(cls_or_fn):
- """Helper used to define a defaulted list field.
+ """Helper used to define a defaulted list field.
- A lambda can be used to defer resolution of a field type that references its
- container type, for example:
- ```
- class Foo:
- subtypes: list['Foo'] = list_field(lambda: Foo)
- names: list[str] = list_field(str)
- ```
+ A lambda can be used to defer resolution of a field type that references its
+ container type, for example:
+ ```
+ class Foo:
+ subtypes: list['Foo'] = list_field(lambda: Foo)
+ names: list[str] = list_field(str)
+ ```
- Args:
- cls_or_fn: The class type or a function that resolves to the class type.
- """
+ Args:
+ cls_or_fn: The class type or a function that resolves to the class type.
+ """
- def list_factory(c):
- return CopyValuesList(c if isinstance(c, type) else c())
+ def list_factory(c):
+ return CopyValuesList(c if isinstance(c, type) else c())
- return dataclasses.field( # pylint:disable=invalid-field-call
- default_factory=lambda: list_factory(cls_or_fn)
- )
+ return dataclasses.field( # pylint:disable=invalid-field-call
+ default_factory=lambda: list_factory(cls_or_fn)
+ )
diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py
index fa99943..2853fde 100644
--- a/compiler/util/ir_data_fields_test.py
+++ b/compiler/util/ir_data_fields_test.py
@@ -25,230 +25,229 @@
class TestEnum(enum.Enum):
- """Used to test python Enum handling."""
+ """Used to test python Enum handling."""
- UNKNOWN = 0
- VALUE_1 = 1
- VALUE_2 = 2
+ UNKNOWN = 0
+ VALUE_1 = 1
+ VALUE_2 = 2
@dataclasses.dataclass
class Opaque(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
- opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
- integer: Optional[int] = ir_data_fields.oneof_field("type")
- boolean: Optional[bool] = ir_data_fields.oneof_field("type")
- enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
- non_union_field: int = 0
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
+ integer: Optional[int] = ir_data_fields.oneof_field("type")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
+ non_union_field: int = 0
@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
- opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
- integer: Optional[int] = ir_data_fields.oneof_field("type_1")
- boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
- enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
- non_union_field: int = 0
- seq_field: list[int] = ir_data_fields.list_field(int)
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
+ integer: Optional[int] = ir_data_fields.oneof_field("type_1")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
+ non_union_field: int = 0
+ seq_field: list[int] = ir_data_fields.list_field(int)
@dataclasses.dataclass
class NestedClass(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
- one_union_class: Optional[ClassWithUnion] = None
- two_union_class: Optional[ClassWithTwoUnions] = None
+ one_union_class: Optional[ClassWithUnion] = None
+ two_union_class: Optional[ClassWithTwoUnions] = None
@dataclasses.dataclass
class ListCopyTestClass(ir_data.Message):
- """Used to test behavior or extending a sequence."""
+ """Used to test behavior or extending a sequence."""
- non_union_field: int = 0
- seq_field: list[int] = ir_data_fields.list_field(int)
+ non_union_field: int = 0
+ seq_field: list[int] = ir_data_fields.list_field(int)
@dataclasses.dataclass
class OneofFieldTest(ir_data.Message):
- """Basic test class for oneof fields"""
+ """Basic test class for oneof fields"""
- int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
- int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
- normal_field: bool = True
+ int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
+ int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
+ normal_field: bool = True
class OneOfTest(unittest.TestCase):
- """Tests for the the various oneof field helpers"""
+ """Tests for the the various oneof field helpers"""
- def test_field_attribute(self):
- """Test the `oneof_field` helper."""
- test_field = ir_data_fields.oneof_field("type_1")
- self.assertIsNotNone(test_field)
- self.assertTrue(test_field.init)
- self.assertIsInstance(test_field.default, ir_data_fields.OneOfField)
- self.assertEqual(test_field.metadata.get("oneof"), "type_1")
+ def test_field_attribute(self):
+ """Test the `oneof_field` helper."""
+ test_field = ir_data_fields.oneof_field("type_1")
+ self.assertIsNotNone(test_field)
+ self.assertTrue(test_field.init)
+ self.assertIsInstance(test_field.default, ir_data_fields.OneOfField)
+ self.assertEqual(test_field.metadata.get("oneof"), "type_1")
- def test_init_default(self):
- """Test creating an instance with default fields"""
- one_of_field_test = OneofFieldTest()
- self.assertIsNone(one_of_field_test.int_field_1)
- self.assertIsNone(one_of_field_test.int_field_2)
- self.assertTrue(one_of_field_test.normal_field)
+ def test_init_default(self):
+ """Test creating an instance with default fields"""
+ one_of_field_test = OneofFieldTest()
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertTrue(one_of_field_test.normal_field)
- def test_init(self):
- """Test creating an instance with non-default fields"""
- one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
- self.assertEqual(one_of_field_test.int_field_1, 10)
- self.assertIsNone(one_of_field_test.int_field_2)
- self.assertFalse(one_of_field_test.normal_field)
+ def test_init(self):
+ """Test creating an instance with non-default fields"""
+ one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertFalse(one_of_field_test.normal_field)
- def test_set_oneof_field(self):
- """Tests setting oneof fields causes others in the group to be unset"""
- one_of_field_test = OneofFieldTest()
- one_of_field_test.int_field_1 = 10
- self.assertEqual(one_of_field_test.int_field_1, 10)
- self.assertEqual(one_of_field_test.int_field_2, None)
- one_of_field_test.int_field_2 = 20
- self.assertEqual(one_of_field_test.int_field_1, None)
- self.assertEqual(one_of_field_test.int_field_2, 20)
+ def test_set_oneof_field(self):
+ """Tests setting oneof fields causes others in the group to be unset"""
+ one_of_field_test = OneofFieldTest()
+ one_of_field_test.int_field_1 = 10
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertEqual(one_of_field_test.int_field_2, None)
+ one_of_field_test.int_field_2 = 20
+ self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertEqual(one_of_field_test.int_field_2, 20)
- # Do it again
- one_of_field_test.int_field_1 = 10
- self.assertEqual(one_of_field_test.int_field_1, 10)
- self.assertEqual(one_of_field_test.int_field_2, None)
- one_of_field_test.int_field_2 = 20
- self.assertEqual(one_of_field_test.int_field_1, None)
- self.assertEqual(one_of_field_test.int_field_2, 20)
+ # Do it again
+ one_of_field_test.int_field_1 = 10
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertEqual(one_of_field_test.int_field_2, None)
+ one_of_field_test.int_field_2 = 20
+ self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertEqual(one_of_field_test.int_field_2, 20)
- # Now create a new instance and make sure changes to it are not reflected
- # on the original object.
- one_of_field_test_2 = OneofFieldTest()
- one_of_field_test_2.int_field_1 = 1000
- self.assertEqual(one_of_field_test_2.int_field_1, 1000)
- self.assertEqual(one_of_field_test_2.int_field_2, None)
- self.assertEqual(one_of_field_test.int_field_1, None)
- self.assertEqual(one_of_field_test.int_field_2, 20)
+ # Now create a new instance and make sure changes to it are not reflected
+ # on the original object.
+ one_of_field_test_2 = OneofFieldTest()
+ one_of_field_test_2.int_field_1 = 1000
+ self.assertEqual(one_of_field_test_2.int_field_1, 1000)
+ self.assertEqual(one_of_field_test_2.int_field_2, None)
+ self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertEqual(one_of_field_test.int_field_2, 20)
- def test_set_to_none(self):
- """Tests explicitly setting a oneof field to None"""
- one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
- self.assertEqual(one_of_field_test.int_field_1, 10)
- self.assertIsNone(one_of_field_test.int_field_2)
- self.assertFalse(one_of_field_test.normal_field)
+ def test_set_to_none(self):
+ """Tests explicitly setting a oneof field to None"""
+ one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertFalse(one_of_field_test.normal_field)
- # Clear the set fields
- one_of_field_test.int_field_1 = None
- self.assertIsNone(one_of_field_test.int_field_1)
- self.assertIsNone(one_of_field_test.int_field_2)
- self.assertFalse(one_of_field_test.normal_field)
+ # Clear the set fields
+ one_of_field_test.int_field_1 = None
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertFalse(one_of_field_test.normal_field)
- # Set another field
- one_of_field_test.int_field_2 = 200
- self.assertIsNone(one_of_field_test.int_field_1)
- self.assertEqual(one_of_field_test.int_field_2, 200)
- self.assertFalse(one_of_field_test.normal_field)
+ # Set another field
+ one_of_field_test.int_field_2 = 200
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertEqual(one_of_field_test.int_field_2, 200)
+ self.assertFalse(one_of_field_test.normal_field)
- # Clear the already unset field
- one_of_field_test.int_field_1 = None
- self.assertIsNone(one_of_field_test.int_field_1)
- self.assertEqual(one_of_field_test.int_field_2, 200)
- self.assertFalse(one_of_field_test.normal_field)
+ # Clear the already unset field
+ one_of_field_test.int_field_1 = None
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertEqual(one_of_field_test.int_field_2, 200)
+ self.assertFalse(one_of_field_test.normal_field)
- def test_oneof_specs(self):
- """Tests the `oneof_field_specs` filter"""
- expected = {
- "int_field_1": ir_data_fields.make_field_spec(
- "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
- ),
- "int_field_2": ir_data_fields.make_field_spec(
- "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
- ),
- }
- actual = ir_data_fields.IrDataclassSpecs.get_specs(
- OneofFieldTest
- ).oneof_field_specs
- self.assertDictEqual(actual, expected)
+ def test_oneof_specs(self):
+ """Tests the `oneof_field_specs` filter"""
+ expected = {
+ "int_field_1": ir_data_fields.make_field_spec(
+ "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
+ ),
+ "int_field_2": ir_data_fields.make_field_spec(
+ "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
+ ),
+ }
+ actual = ir_data_fields.IrDataclassSpecs.get_specs(
+ OneofFieldTest
+ ).oneof_field_specs
+ self.assertDictEqual(actual, expected)
- def test_oneof_mappings(self):
- """Tests the `oneof_mappings` function"""
- expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
- actual = ir_data_fields.IrDataclassSpecs.get_specs(
- OneofFieldTest
- ).oneof_mappings
- self.assertTupleEqual(actual, expected)
+ def test_oneof_mappings(self):
+ """Tests the `oneof_mappings` function"""
+ expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
+ actual = ir_data_fields.IrDataclassSpecs.get_specs(
+ OneofFieldTest
+ ).oneof_mappings
+ self.assertTupleEqual(actual, expected)
class IrDataFieldsTest(unittest.TestCase):
- """Tests misc methods in ir_data_fields"""
+ """Tests misc methods in ir_data_fields"""
- def test_copy(self):
- """Tests copying a data class works as expected"""
- union = ClassWithTwoUnions(
- opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
- )
- nested_class = NestedClass(two_union_class=union)
- nested_class_copy = ir_data_fields.copy(nested_class)
- self.assertIsNotNone(nested_class_copy)
- self.assertIsNot(nested_class, nested_class_copy)
- self.assertEqual(nested_class_copy, nested_class)
+ def test_copy(self):
+ """Tests copying a data class works as expected"""
+ union = ClassWithTwoUnions(
+ opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
+ )
+ nested_class = NestedClass(two_union_class=union)
+ nested_class_copy = ir_data_fields.copy(nested_class)
+ self.assertIsNotNone(nested_class_copy)
+ self.assertIsNot(nested_class, nested_class_copy)
+ self.assertEqual(nested_class_copy, nested_class)
- empty_copy = ir_data_fields.copy(None)
- self.assertIsNone(empty_copy)
+ empty_copy = ir_data_fields.copy(None)
+ self.assertIsNone(empty_copy)
- def test_copy_values_list(self):
- """Tests that CopyValuesList copies values"""
- data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
- self.assertEqual(len(data_list), 0)
+ def test_copy_values_list(self):
+ """Tests that CopyValuesList copies values"""
+ data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
+ self.assertEqual(len(data_list), 0)
- list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
- list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
- data_list.extend(list_tests)
- self.assertEqual(len(data_list), 4)
- for i in data_list:
- self.assertEqual(i, list_test)
+ list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
+ list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
+ data_list.extend(list_tests)
+ self.assertEqual(len(data_list), 4)
+ for i in data_list:
+ self.assertEqual(i, list_test)
- def test_list_param_is_copied(self):
- """Test that lists passed to constructors are converted to CopyValuesList"""
- seq_field = [5, 6, 7]
- list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
- self.assertEqual(len(list_test.seq_field), len(seq_field))
- self.assertIsNot(list_test.seq_field, seq_field)
- self.assertEqual(list_test.seq_field, seq_field)
- self.assertTrue(
- isinstance(list_test.seq_field, ir_data_fields.CopyValuesList)
- )
+ def test_list_param_is_copied(self):
+ """Test that lists passed to constructors are converted to CopyValuesList"""
+ seq_field = [5, 6, 7]
+ list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
+ self.assertEqual(len(list_test.seq_field), len(seq_field))
+ self.assertIsNot(list_test.seq_field, seq_field)
+ self.assertEqual(list_test.seq_field, seq_field)
+ self.assertTrue(isinstance(list_test.seq_field, ir_data_fields.CopyValuesList))
- def test_copy_oneof(self):
- """Tests copying an IR data class that has oneof fields."""
- oneof_test = OneofFieldTest()
- oneof_test.int_field_1 = 10
- oneof_test.normal_field = False
- self.assertEqual(oneof_test.int_field_1, 10)
- self.assertEqual(oneof_test.normal_field, False)
+ def test_copy_oneof(self):
+ """Tests copying an IR data class that has oneof fields."""
+ oneof_test = OneofFieldTest()
+ oneof_test.int_field_1 = 10
+ oneof_test.normal_field = False
+ self.assertEqual(oneof_test.int_field_1, 10)
+ self.assertEqual(oneof_test.normal_field, False)
- oneof_copy = ir_data_fields.copy(oneof_test)
- self.assertIsNotNone(oneof_copy)
- self.assertEqual(oneof_copy.int_field_1, 10)
- self.assertIsNone(oneof_copy.int_field_2)
- self.assertEqual(oneof_copy.normal_field, False)
+ oneof_copy = ir_data_fields.copy(oneof_test)
+ self.assertIsNotNone(oneof_copy)
+ self.assertEqual(oneof_copy.int_field_1, 10)
+ self.assertIsNone(oneof_copy.int_field_2)
+ self.assertEqual(oneof_copy.normal_field, False)
- oneof_copy.int_field_2 = 100
- self.assertEqual(oneof_copy.int_field_2, 100)
- self.assertIsNone(oneof_copy.int_field_1)
- self.assertEqual(oneof_test.int_field_1, 10)
- self.assertEqual(oneof_test.normal_field, False)
+ oneof_copy.int_field_2 = 100
+ self.assertEqual(oneof_copy.int_field_2, 100)
+ self.assertIsNone(oneof_copy.int_field_1)
+ self.assertEqual(oneof_test.int_field_1, 10)
+ self.assertEqual(oneof_test.normal_field, False)
ir_data_fields.cache_message_specs(
- sys.modules[OneofFieldTest.__module__], ir_data.Message)
+ sys.modules[OneofFieldTest.__module__], ir_data.Message
+)
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index 37cc9a3..2d38eac 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -85,383 +85,382 @@
def field_specs(ir: Union[MessageT, type[MessageT]]):
- """Retrieves the field specs for the IR data class"""
- data_type = ir if isinstance(ir, type) else type(ir)
- return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs
+ """Retrieves the field specs for the IR data class"""
+ data_type = ir if isinstance(ir, type) else type(ir)
+ return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs
class IrDataSerializer:
- """Provides methods for serializing IR data objects"""
+ """Provides methods for serializing IR data objects"""
- def __init__(self, ir: MessageT):
- assert ir is not None
- self.ir = ir
+ def __init__(self, ir: MessageT):
+ assert ir is not None
+ self.ir = ir
- def _to_dict(
- self,
- ir: MessageT,
- field_func: Callable[
- [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]
- ],
- ) -> MutableMapping[str, Any]:
- assert ir is not None
- values: MutableMapping[str, Any] = {}
- for spec, value in field_func(ir):
- if value is not None and spec.is_dataclass:
- if spec.is_sequence:
- value = [self._to_dict(v, field_func) for v in value]
- else:
- value = self._to_dict(value, field_func)
- values[spec.name] = value
- return values
+ def _to_dict(
+ self,
+ ir: MessageT,
+ field_func: Callable[[MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]],
+ ) -> MutableMapping[str, Any]:
+ assert ir is not None
+ values: MutableMapping[str, Any] = {}
+ for spec, value in field_func(ir):
+ if value is not None and spec.is_dataclass:
+ if spec.is_sequence:
+ value = [self._to_dict(v, field_func) for v in value]
+ else:
+ value = self._to_dict(value, field_func)
+ values[spec.name] = value
+ return values
- def to_dict(self, exclude_none: bool = False):
- """Converts the IR data class to a dictionary."""
+ def to_dict(self, exclude_none: bool = False):
+ """Converts the IR data class to a dictionary."""
- def non_empty(ir):
- return fields_and_values(
- ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
- )
-
- def all_fields(ir):
- return fields_and_values(ir)
-
- # It's tempting to use `dataclasses.asdict` here, but that does a deep
- # copy which is overkill for the current usage; mainly as an intermediary
- # for `to_json` and `repr`.
- return self._to_dict(self.ir, non_empty if exclude_none else all_fields)
-
- def to_json(self, *args, **kwargs):
- """Converts the IR data class to a JSON string"""
- return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs)
-
- @staticmethod
- def from_json(data_cls, data):
- """Constructs an IR data class from the given JSON string"""
- as_dict = json.loads(data)
- return IrDataSerializer.from_dict(data_cls, as_dict)
-
- def copy_from_dict(self, data):
- """Deserializes the data and overwrites the IR data class with it"""
- cls = type(self.ir)
- data_copy = IrDataSerializer.from_dict(cls, data)
- for k in field_specs(cls):
- setattr(self.ir, k, getattr(data_copy, k))
-
- @staticmethod
- def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
- if isinstance(val, str):
- return getattr(enum_cls, val)
- return enum_cls(val)
-
- @staticmethod
- def _enum_type_hook(enum_cls: type[enum.Enum]):
- return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val)
-
- @staticmethod
- def _from_dict(data_cls: type[MessageT], data):
- class_fields: MutableMapping[str, Any] = {}
- for name, spec in ir_data_fields.field_specs(data_cls).items():
- if (value := data.get(name)) is not None:
- if spec.is_dataclass:
- if spec.is_sequence:
- class_fields[name] = [
- IrDataSerializer._from_dict(spec.data_type, v) for v in value
- ]
- else:
- class_fields[name] = IrDataSerializer._from_dict(
- spec.data_type, value
+ def non_empty(ir):
+ return fields_and_values(
+ ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
)
- else:
- if spec.data_type in (
- ir_data.FunctionMapping,
- ir_data.AddressableUnit,
- ):
- class_fields[name] = IrDataSerializer._enum_type_converter(
- spec.data_type, value
- )
- else:
- if spec.is_sequence:
- class_fields[name] = value
- else:
- class_fields[name] = spec.data_type(value)
- return data_cls(**class_fields)
- @staticmethod
- def from_dict(data_cls: type[MessageT], data):
- """Creates a new IR data instance from a serialized dict"""
- return IrDataSerializer._from_dict(data_cls, data)
+ def all_fields(ir):
+ return fields_and_values(ir)
+
+ # It's tempting to use `dataclasses.asdict` here, but that does a deep
+ # copy which is overkill for the current usage; mainly as an intermediary
+ # for `to_json` and `repr`.
+ return self._to_dict(self.ir, non_empty if exclude_none else all_fields)
+
+ def to_json(self, *args, **kwargs):
+ """Converts the IR data class to a JSON string"""
+ return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs)
+
+ @staticmethod
+ def from_json(data_cls, data):
+ """Constructs an IR data class from the given JSON string"""
+ as_dict = json.loads(data)
+ return IrDataSerializer.from_dict(data_cls, as_dict)
+
+ def copy_from_dict(self, data):
+ """Deserializes the data and overwrites the IR data class with it"""
+ cls = type(self.ir)
+ data_copy = IrDataSerializer.from_dict(cls, data)
+ for k in field_specs(cls):
+ setattr(self.ir, k, getattr(data_copy, k))
+
+ @staticmethod
+ def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
+ if isinstance(val, str):
+ return getattr(enum_cls, val)
+ return enum_cls(val)
+
+ @staticmethod
+ def _enum_type_hook(enum_cls: type[enum.Enum]):
+ return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val)
+
+ @staticmethod
+ def _from_dict(data_cls: type[MessageT], data):
+ class_fields: MutableMapping[str, Any] = {}
+ for name, spec in ir_data_fields.field_specs(data_cls).items():
+ if (value := data.get(name)) is not None:
+ if spec.is_dataclass:
+ if spec.is_sequence:
+ class_fields[name] = [
+ IrDataSerializer._from_dict(spec.data_type, v)
+ for v in value
+ ]
+ else:
+ class_fields[name] = IrDataSerializer._from_dict(
+ spec.data_type, value
+ )
+ else:
+ if spec.data_type in (
+ ir_data.FunctionMapping,
+ ir_data.AddressableUnit,
+ ):
+ class_fields[name] = IrDataSerializer._enum_type_converter(
+ spec.data_type, value
+ )
+ else:
+ if spec.is_sequence:
+ class_fields[name] = value
+ else:
+ class_fields[name] = spec.data_type(value)
+ return data_cls(**class_fields)
+
+ @staticmethod
+ def from_dict(data_cls: type[MessageT], data):
+ """Creates a new IR data instance from a serialized dict"""
+ return IrDataSerializer._from_dict(data_cls, data)
class _IrDataSequenceBuilder(MutableSequence[MessageT]):
- """Wrapper for a list of IR elements
+ """Wrapper for a list of IR elements
- Simply wraps the returned values during indexed access and iteration with
- IrDataBuilders.
- """
+ Simply wraps the returned values during indexed access and iteration with
+ IrDataBuilders.
+ """
- def __init__(self, target: MutableSequence[MessageT]):
- self._target = target
+ def __init__(self, target: MutableSequence[MessageT]):
+ self._target = target
- def __delitem__(self, key):
- del self._target[key]
+ def __delitem__(self, key):
+ del self._target[key]
- def __getitem__(self, key):
- return _IrDataBuilder(self._target.__getitem__(key))
+ def __getitem__(self, key):
+ return _IrDataBuilder(self._target.__getitem__(key))
- def __setitem__(self, key, value):
- self._target[key] = value
+ def __setitem__(self, key, value):
+ self._target[key] = value
- def __iter__(self):
- itr = iter(self._target)
- for i in itr:
- yield _IrDataBuilder(i)
+ def __iter__(self):
+ itr = iter(self._target)
+ for i in itr:
+ yield _IrDataBuilder(i)
- def __repr__(self):
- return repr(self._target)
+ def __repr__(self):
+ return repr(self._target)
- def __len__(self):
- return len(self._target)
+ def __len__(self):
+ return len(self._target)
- def __eq__(self, other):
- return self._target == other
+ def __eq__(self, other):
+ return self._target == other
- def __ne__(self, other):
- return self._target != other
+ def __ne__(self, other):
+ return self._target != other
- def insert(self, index, value):
- self._target.insert(index, value)
+ def insert(self, index, value):
+ self._target.insert(index, value)
- def extend(self, values):
- self._target.extend(values)
+ def extend(self, values):
+ self._target.extend(values)
class _IrDataBuilder(Generic[MessageT]):
- """Wrapper for an IR element"""
+ """Wrapper for an IR element"""
- def __init__(self, ir: MessageT) -> None:
- assert ir is not None
- self.ir: MessageT = ir
+ def __init__(self, ir: MessageT) -> None:
+ assert ir is not None
+ self.ir: MessageT = ir
- def __setattr__(self, __name: str, __value: Any) -> None:
- if __name == "ir":
- # This our proxy object
- object.__setattr__(self, __name, __value)
- else:
- # Passthrough to the proxy object
- ir: MessageT = object.__getattribute__(self, "ir")
- setattr(ir, __name, __value)
+ def __setattr__(self, __name: str, __value: Any) -> None:
+ if __name == "ir":
+            # This is our proxy object
+ object.__setattr__(self, __name, __value)
+ else:
+ # Passthrough to the proxy object
+ ir: MessageT = object.__getattribute__(self, "ir")
+ setattr(ir, __name, __value)
- def __getattribute__(self, name: str) -> Any:
- """Hook for `getattr` that handles adding missing fields.
+ def __getattribute__(self, name: str) -> Any:
+ """Hook for `getattr` that handles adding missing fields.
- If the field is missing inserts it, and then returns either the raw value
- for basic types
- or a new IrBuilder wrapping the field to handle the next field access in a
- longer chain.
- """
+ If the field is missing inserts it, and then returns either the raw value
+ for basic types
+ or a new IrBuilder wrapping the field to handle the next field access in a
+ longer chain.
+ """
- # Check if getting one of the builder attributes
- if name in ("CopyFrom", "ir"):
- return object.__getattribute__(self, name)
+ # Check if getting one of the builder attributes
+ if name in ("CopyFrom", "ir"):
+ return object.__getattribute__(self, name)
- # Get our target object by bypassing our getattr hook
- ir: MessageT = object.__getattribute__(self, "ir")
- if ir is None:
- return object.__getattribute__(self, name)
+ # Get our target object by bypassing our getattr hook
+ ir: MessageT = object.__getattribute__(self, "ir")
+ if ir is None:
+ return object.__getattribute__(self, name)
- if name in ("HasField", "WhichOneof"):
- return getattr(ir, name)
+ if name in ("HasField", "WhichOneof"):
+ return getattr(ir, name)
- field_spec = field_specs(ir).get(name)
- if field_spec is None:
- raise AttributeError(
- f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
- )
+ field_spec = field_specs(ir).get(name)
+ if field_spec is None:
+ raise AttributeError(
+ f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
+ )
- obj = getattr(ir, name, None)
- if obj is None:
- # Create a default and store it
- obj = ir_data_fields.build_default(field_spec)
- setattr(ir, name, obj)
+ obj = getattr(ir, name, None)
+ if obj is None:
+ # Create a default and store it
+ obj = ir_data_fields.build_default(field_spec)
+ setattr(ir, name, obj)
- if field_spec.is_dataclass:
- obj = (
- _IrDataSequenceBuilder(obj)
- if field_spec.is_sequence
- else _IrDataBuilder(obj)
- )
+ if field_spec.is_dataclass:
+ obj = (
+ _IrDataSequenceBuilder(obj)
+ if field_spec.is_sequence
+ else _IrDataBuilder(obj)
+ )
- return obj
+ return obj
- def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name
- """Updates the fields of this class with values set in the template"""
- update(cast(type[MessageT], self), template)
+ def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name
+ """Updates the fields of this class with values set in the template"""
+ update(cast(type[MessageT], self), template)
def builder(target: MessageT) -> MessageT:
- """Create a wrapper around the target to help build an IR Data structure"""
- # Check if the target is already a builder.
- if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)):
- return target
+ """Create a wrapper around the target to help build an IR Data structure"""
+ # Check if the target is already a builder.
+ if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)):
+ return target
- # Builders are only valid for IR data classes.
- if not hasattr(type(target), "IR_DATACLASS"):
- raise TypeError(f"Builder target {type(target)} is not an ir_data.message")
+ # Builders are only valid for IR data classes.
+ if not hasattr(type(target), "IR_DATACLASS"):
+ raise TypeError(f"Builder target {type(target)} is not an ir_data.message")
- # Create a builder and cast it to the target type to expose type hinting for
- # the wrapped type.
- return cast(MessageT, _IrDataBuilder(target))
+ # Create a builder and cast it to the target type to expose type hinting for
+ # the wrapped type.
+ return cast(MessageT, _IrDataBuilder(target))
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
- """Helper that builds an FieldChecker that pretends to be an IR class"""
- if spec.is_sequence:
- return []
- if spec.is_dataclass:
- return _ReadOnlyFieldChecker(spec)
- return ir_data_fields.build_default(spec)
+ """Helper that builds an FieldChecker that pretends to be an IR class"""
+ if spec.is_sequence:
+ return []
+ if spec.is_dataclass:
+ return _ReadOnlyFieldChecker(spec)
+ return ir_data_fields.build_default(spec)
def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type:
- if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
- return ir_or_spec.data_type
- return type(ir_or_spec)
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ return ir_or_spec.data_type
+ return type(ir_or_spec)
class _ReadOnlyFieldChecker:
- """Class used the chain calls to fields that aren't set"""
+ """Class used the chain calls to fields that aren't set"""
- def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
- self.ir_or_spec = ir_or_spec
+ def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
+ self.ir_or_spec = ir_or_spec
- def __setattr__(self, name: str, value: Any) -> None:
- if name == "ir_or_spec":
- return object.__setattr__(self, name, value)
+ def __setattr__(self, name: str, value: Any) -> None:
+ if name == "ir_or_spec":
+ return object.__setattr__(self, name, value)
- raise AttributeError(f"Cannot set {name} on read-only wrapper")
+ raise AttributeError(f"Cannot set {name} on read-only wrapper")
- def __getattribute__(self, name: str) -> Any: # pylint:disable=too-many-return-statements
- ir_or_spec = object.__getattribute__(self, "ir_or_spec")
- if name == "ir_or_spec":
- return ir_or_spec
+ def __getattribute__(
+ self, name: str
+ ) -> Any: # pylint:disable=too-many-return-statements
+ ir_or_spec = object.__getattribute__(self, "ir_or_spec")
+ if name == "ir_or_spec":
+ return ir_or_spec
- field_type = _field_type(ir_or_spec)
- spec = field_specs(field_type).get(name)
- if not spec:
- if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
- if name == "HasField":
- return lambda x: False
- if name == "WhichOneof":
- return lambda x: None
- return object.__getattribute__(ir_or_spec, name)
+ field_type = _field_type(ir_or_spec)
+ spec = field_specs(field_type).get(name)
+ if not spec:
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ if name == "HasField":
+ return lambda x: False
+ if name == "WhichOneof":
+ return lambda x: None
+ return object.__getattribute__(ir_or_spec, name)
- if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
- # Just pretending
- return _field_checker_from_spec(spec)
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ # Just pretending
+ return _field_checker_from_spec(spec)
- value = getattr(ir_or_spec, name)
- if value is None:
- return _field_checker_from_spec(spec)
+ value = getattr(ir_or_spec, name)
+ if value is None:
+ return _field_checker_from_spec(spec)
- if spec.is_dataclass:
- if spec.is_sequence:
- return [_ReadOnlyFieldChecker(i) for i in value]
- return _ReadOnlyFieldChecker(value)
+ if spec.is_dataclass:
+ if spec.is_sequence:
+ return [_ReadOnlyFieldChecker(i) for i in value]
+ return _ReadOnlyFieldChecker(value)
- return value
+ return value
- def __eq__(self, other):
- if isinstance(other, _ReadOnlyFieldChecker):
- other = other.ir_or_spec
- return self.ir_or_spec == other
+ def __eq__(self, other):
+ if isinstance(other, _ReadOnlyFieldChecker):
+ other = other.ir_or_spec
+ return self.ir_or_spec == other
- def __ne__(self, other):
- return not self == other
+ def __ne__(self, other):
+ return not self == other
def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT:
- """Builds a read-only wrapper that can be used to check chains of possibly
- unset fields.
+ """Builds a read-only wrapper that can be used to check chains of possibly
+ unset fields.
- This wrapper explicitly does not alter the wrapped object and is only
- intended for reading contents.
+ This wrapper explicitly does not alter the wrapped object and is only
+ intended for reading contents.
- For example, a `reader` lets you do:
- ```
- def get_function_name_end_column(function: ir_data.Function):
- return reader(function).function_name.source_location.end.column
- ```
+ For example, a `reader` lets you do:
+ ```
+ def get_function_name_end_column(function: ir_data.Function):
+ return reader(function).function_name.source_location.end.column
+ ```
- Instead of:
- ```
- def get_function_name_end_column(function: ir_data.Function):
- if function.function_name:
- if function.function_name.source_location:
- if function.function_name.source_location.end:
- return function.function_name.source_location.end.column
- return 0
- ```
- """
- # Create a read-only wrapper if it's not already one.
- if not isinstance(obj, _ReadOnlyFieldChecker):
- obj = _ReadOnlyFieldChecker(obj)
+ Instead of:
+ ```
+ def get_function_name_end_column(function: ir_data.Function):
+ if function.function_name:
+ if function.function_name.source_location:
+ if function.function_name.source_location.end:
+ return function.function_name.source_location.end.column
+ return 0
+ ```
+ """
+ # Create a read-only wrapper if it's not already one.
+ if not isinstance(obj, _ReadOnlyFieldChecker):
+ obj = _ReadOnlyFieldChecker(obj)
- # Cast it back to the original type.
- return cast(MessageT, obj)
+ # Cast it back to the original type.
+ return cast(MessageT, obj)
def _extract_ir(
ir_or_wrapper: Union[MessageT, _ReadOnlyFieldChecker, _IrDataBuilder, None],
) -> Optional[ir_data_fields.IrDataclassInstance]:
- if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
- ir_or_spec = ir_or_wrapper.ir_or_spec
- if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
- # This is a placeholder entry, no fields are set.
- return None
- ir_or_wrapper = ir_or_spec
- elif isinstance(ir_or_wrapper, _IrDataBuilder):
- ir_or_wrapper = ir_or_wrapper.ir
- return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
+ if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
+ ir_or_spec = ir_or_wrapper.ir_or_spec
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ # This is a placeholder entry, no fields are set.
+ return None
+ ir_or_wrapper = ir_or_spec
+ elif isinstance(ir_or_wrapper, _IrDataBuilder):
+ ir_or_wrapper = ir_or_wrapper.ir
+ return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
def fields_and_values(
ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker],
value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
- """Retrieves the fields and their values for a given IR data class.
+ """Retrieves the fields and their values for a given IR data class.
- Args:
- ir: The IR data class or a read-only wrapper of an IR data class.
- value_filt: Optional filter used to exclude values.
- """
- if (ir := _extract_ir(ir_wrapper)) is None:
- return []
+ Args:
+        ir_wrapper: The IR data class or a read-only wrapper of an IR data class.
+ value_filt: Optional filter used to exclude values.
+ """
+ if (ir := _extract_ir(ir_wrapper)) is None:
+ return []
- return ir_data_fields.fields_and_values(ir, value_filt)
+ return ir_data_fields.fields_and_values(ir, value_filt)
def get_set_fields(ir: MessageT):
- """Retrieves the field spec and value of fields that are set in the given IR data class.
+ """Retrieves the field spec and value of fields that are set in the given IR data class.
- A value is considered "set" if it is not None.
- """
- return fields_and_values(ir, lambda v: v is not None)
+ A value is considered "set" if it is not None.
+ """
+ return fields_and_values(ir, lambda v: v is not None)
def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
- """Creates a copy of the given IR data class"""
- if (ir := _extract_ir(ir_wrapper)) is None:
- return None
- ir_copy = ir_data_fields.copy(ir)
- return cast(MessageT, ir_copy)
+ """Creates a copy of the given IR data class"""
+ if (ir := _extract_ir(ir_wrapper)) is None:
+ return None
+ ir_copy = ir_data_fields.copy(ir)
+ return cast(MessageT, ir_copy)
def update(ir: MessageT, template: MessageT):
- """Updates `ir`s fields with all set fields in the template."""
- if not (template_ir := _extract_ir(template)):
- return
+ """Updates `ir`s fields with all set fields in the template."""
+ if not (template_ir := _extract_ir(template)):
+ return
- ir_data_fields.update(
- cast(ir_data_fields.IrDataclassInstance, ir), template_ir
- )
+ ir_data_fields.update(cast(ir_data_fields.IrDataclassInstance, ir), template_ir)
diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py
index 1c000ec..82baac6 100644
--- a/compiler/util/ir_data_utils_test.py
+++ b/compiler/util/ir_data_utils_test.py
@@ -26,625 +26,608 @@
class TestEnum(enum.Enum):
- """Used to test python Enum handling."""
+ """Used to test python Enum handling."""
- UNKNOWN = 0
- VALUE_1 = 1
- VALUE_2 = 2
+ UNKNOWN = 0
+ VALUE_1 = 1
+ VALUE_2 = 2
@dataclasses.dataclass
class Opaque(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
- opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
- integer: Optional[int] = ir_data_fields.oneof_field("type")
- boolean: Optional[bool] = ir_data_fields.oneof_field("type")
- enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
- non_union_field: int = 0
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
+ integer: Optional[int] = ir_data_fields.oneof_field("type")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
+ non_union_field: int = 0
@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers"""
- opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
- integer: Optional[int] = ir_data_fields.oneof_field("type_1")
- boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
- enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
- non_union_field: int = 0
- seq_field: list[int] = ir_data_fields.list_field(int)
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
+ integer: Optional[int] = ir_data_fields.oneof_field("type_1")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
+ non_union_field: int = 0
+ seq_field: list[int] = ir_data_fields.list_field(int)
class IrDataUtilsTest(unittest.TestCase):
- """Tests for the miscellaneous utility functions in ir_data_utils.py."""
+ """Tests for the miscellaneous utility functions in ir_data_utils.py."""
- def test_field_specs(self):
- """Tests the `field_specs` method"""
- fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
- self.assertIsNotNone(fields)
- expected_fields = (
- "external",
- "enumeration",
- "structure",
- "name",
- "attribute",
- "documentation",
- "subtype",
- "addressable_unit",
- "runtime_parameter",
- "source_location",
- )
- self.assertEqual(len(fields), len(expected_fields))
- field_names = fields.keys()
- for k in expected_fields:
- self.assertIn(k, field_names)
+ def test_field_specs(self):
+ """Tests the `field_specs` method"""
+ fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
+ self.assertIsNotNone(fields)
+ expected_fields = (
+ "external",
+ "enumeration",
+ "structure",
+ "name",
+ "attribute",
+ "documentation",
+ "subtype",
+ "addressable_unit",
+ "runtime_parameter",
+ "source_location",
+ )
+ self.assertEqual(len(fields), len(expected_fields))
+ field_names = fields.keys()
+ for k in expected_fields:
+ self.assertIn(k, field_names)
- # Try a sequence
- expected_field = ir_data_fields.make_field_spec(
- "attribute", ir_data.Attribute, ir_data_fields.FieldContainer.LIST, None
- )
- self.assertEqual(fields["attribute"], expected_field)
+ # Try a sequence
+ expected_field = ir_data_fields.make_field_spec(
+ "attribute", ir_data.Attribute, ir_data_fields.FieldContainer.LIST, None
+ )
+ self.assertEqual(fields["attribute"], expected_field)
- # Try a scalar
- expected_field = ir_data_fields.make_field_spec(
- "addressable_unit",
- ir_data.AddressableUnit,
- ir_data_fields.FieldContainer.OPTIONAL,
- None,
- )
- self.assertEqual(fields["addressable_unit"], expected_field)
+ # Try a scalar
+ expected_field = ir_data_fields.make_field_spec(
+ "addressable_unit",
+ ir_data.AddressableUnit,
+ ir_data_fields.FieldContainer.OPTIONAL,
+ None,
+ )
+ self.assertEqual(fields["addressable_unit"], expected_field)
- # Try a IR data class
- expected_field = ir_data_fields.make_field_spec(
- "source_location",
- ir_data.Location,
- ir_data_fields.FieldContainer.OPTIONAL,
- None,
- )
- self.assertEqual(fields["source_location"], expected_field)
+ # Try a IR data class
+ expected_field = ir_data_fields.make_field_spec(
+ "source_location",
+ ir_data.Location,
+ ir_data_fields.FieldContainer.OPTIONAL,
+ None,
+ )
+ self.assertEqual(fields["source_location"], expected_field)
- # Try an oneof field
- expected_field = ir_data_fields.make_field_spec(
- "external",
- ir_data.External,
- ir_data_fields.FieldContainer.OPTIONAL,
- oneof="type",
- )
- self.assertEqual(fields["external"], expected_field)
+ # Try an oneof field
+ expected_field = ir_data_fields.make_field_spec(
+ "external",
+ ir_data.External,
+ ir_data_fields.FieldContainer.OPTIONAL,
+ oneof="type",
+ )
+ self.assertEqual(fields["external"], expected_field)
- # Try non-optional scalar
- fields = ir_data_utils.field_specs(ir_data.Position)
- expected_field = ir_data_fields.make_field_spec(
- "line", int, ir_data_fields.FieldContainer.NONE, None
- )
- self.assertEqual(fields["line"], expected_field)
+ # Try non-optional scalar
+ fields = ir_data_utils.field_specs(ir_data.Position)
+ expected_field = ir_data_fields.make_field_spec(
+ "line", int, ir_data_fields.FieldContainer.NONE, None
+ )
+ self.assertEqual(fields["line"], expected_field)
- fields = ir_data_utils.field_specs(ir_data.ArrayType)
- expected_field = ir_data_fields.make_field_spec(
- "base_type", ir_data.Type, ir_data_fields.FieldContainer.OPTIONAL, None
- )
- self.assertEqual(fields["base_type"], expected_field)
+ fields = ir_data_utils.field_specs(ir_data.ArrayType)
+ expected_field = ir_data_fields.make_field_spec(
+ "base_type", ir_data.Type, ir_data_fields.FieldContainer.OPTIONAL, None
+ )
+ self.assertEqual(fields["base_type"], expected_field)
- def test_is_sequence(self):
- """Tests for the `FieldSpec.is_sequence` helper"""
- type_def = ir_data.TypeDefinition(
- attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- ),
- ]
- )
- fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
- # Test against a repeated field
- self.assertTrue(fields["attribute"].is_sequence)
- # Test against a nested IR data type
- self.assertFalse(fields["name"].is_sequence)
- # Test against a plain scalar type
- fields = ir_data_utils.field_specs(type_def.attribute[0])
- self.assertFalse(fields["is_default"].is_sequence)
+ def test_is_sequence(self):
+ """Tests for the `FieldSpec.is_sequence` helper"""
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ]
+ )
+ fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
+ # Test against a repeated field
+ self.assertTrue(fields["attribute"].is_sequence)
+ # Test against a nested IR data type
+ self.assertFalse(fields["name"].is_sequence)
+ # Test against a plain scalar type
+ fields = ir_data_utils.field_specs(type_def.attribute[0])
+ self.assertFalse(fields["is_default"].is_sequence)
- def test_is_dataclass(self):
- """Tests FieldSpec.is_dataclass against ir_data"""
- type_def = ir_data.TypeDefinition(
- attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- ),
- ]
- )
- fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
- # Test against a repeated field that holds IR data structs
- self.assertTrue(fields["attribute"].is_dataclass)
- # Test against a nested IR data type
- self.assertTrue(fields["name"].is_dataclass)
- # Test against a plain scalar type
- fields = ir_data_utils.field_specs(type_def.attribute[0])
- self.assertFalse(fields["is_default"].is_dataclass)
- # Test against a repeated field that holds scalars
- fields = ir_data_utils.field_specs(ir_data.Structure)
- self.assertFalse(fields["fields_in_dependency_order"].is_dataclass)
+ def test_is_dataclass(self):
+ """Tests FieldSpec.is_dataclass against ir_data"""
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ]
+ )
+ fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
+ # Test against a repeated field that holds IR data structs
+ self.assertTrue(fields["attribute"].is_dataclass)
+ # Test against a nested IR data type
+ self.assertTrue(fields["name"].is_dataclass)
+ # Test against a plain scalar type
+ fields = ir_data_utils.field_specs(type_def.attribute[0])
+ self.assertFalse(fields["is_default"].is_dataclass)
+ # Test against a repeated field that holds scalars
+ fields = ir_data_utils.field_specs(ir_data.Structure)
+ self.assertFalse(fields["fields_in_dependency_order"].is_dataclass)
- def test_get_set_fields(self):
- """Tests that get set fields works"""
- type_def = ir_data.TypeDefinition(
- attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- ),
- ]
- )
- set_fields = ir_data_utils.get_set_fields(type_def)
- expected_fields = set(
- ["attribute", "documentation", "subtype", "runtime_parameter"]
- )
- self.assertEqual(len(set_fields), len(expected_fields))
- found_fields = set()
- for k, v in set_fields:
- self.assertIn(k.name, expected_fields)
- found_fields.add(k.name)
- self.assertEqual(v, getattr(type_def, k.name))
+ def test_get_set_fields(self):
+ """Tests that get set fields works"""
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ]
+ )
+ set_fields = ir_data_utils.get_set_fields(type_def)
+ expected_fields = set(
+ ["attribute", "documentation", "subtype", "runtime_parameter"]
+ )
+ self.assertEqual(len(set_fields), len(expected_fields))
+ found_fields = set()
+ for k, v in set_fields:
+ self.assertIn(k.name, expected_fields)
+ found_fields.add(k.name)
+ self.assertEqual(v, getattr(type_def, k.name))
- self.assertSetEqual(found_fields, expected_fields)
+ self.assertSetEqual(found_fields, expected_fields)
- def test_copy(self):
- """Tests the `copy` helper"""
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
- attribute_copy = ir_data_utils.copy(attribute)
+ def test_copy(self):
+ """Tests the `copy` helper"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ attribute_copy = ir_data_utils.copy(attribute)
- # Should be equivalent
- self.assertEqual(attribute, attribute_copy)
- # But not the same instance
- self.assertIsNot(attribute, attribute_copy)
+ # Should be equivalent
+ self.assertEqual(attribute, attribute_copy)
+ # But not the same instance
+ self.assertIsNot(attribute, attribute_copy)
- # Let's do a sequence
- type_def = ir_data.TypeDefinition(attribute=[attribute])
- type_def_copy = ir_data_utils.copy(type_def)
+ # Let's do a sequence
+ type_def = ir_data.TypeDefinition(attribute=[attribute])
+ type_def_copy = ir_data_utils.copy(type_def)
- # Should be equivalent
- self.assertEqual(type_def, type_def_copy)
- # But not the same instance
- self.assertIsNot(type_def, type_def_copy)
- self.assertIsNot(type_def.attribute, type_def_copy.attribute)
+ # Should be equivalent
+ self.assertEqual(type_def, type_def_copy)
+ # But not the same instance
+ self.assertIsNot(type_def, type_def_copy)
+ self.assertIsNot(type_def.attribute, type_def_copy.attribute)
- def test_update(self):
- """Tests the `update` helper"""
- attribute_template = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
- attribute = ir_data.Attribute(is_default=True)
- ir_data_utils.update(attribute, attribute_template)
- self.assertIsNotNone(attribute.value)
- self.assertIsNot(attribute.value, attribute_template.value)
- self.assertIsNotNone(attribute.name)
- self.assertIsNot(attribute.name, attribute_template.name)
+ def test_update(self):
+ """Tests the `update` helper"""
+ attribute_template = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ attribute = ir_data.Attribute(is_default=True)
+ ir_data_utils.update(attribute, attribute_template)
+ self.assertIsNotNone(attribute.value)
+ self.assertIsNot(attribute.value, attribute_template.value)
+ self.assertIsNotNone(attribute.name)
+ self.assertIsNot(attribute.name, attribute_template.name)
- # Value not present in template should be untouched
- self.assertTrue(attribute.is_default)
+ # Value not present in template should be untouched
+ self.assertTrue(attribute.is_default)
class IrDataBuilderTest(unittest.TestCase):
- """Tests for IrDataBuilder"""
+ """Tests for IrDataBuilder"""
- def test_ir_data_builder(self):
- """Tests that basic builder chains work"""
- # We start with an empty type
- type_def = ir_data.TypeDefinition()
- self.assertFalse(type_def.HasField("name"))
- self.assertIsNone(type_def.name)
+ def test_ir_data_builder(self):
+ """Tests that basic builder chains work"""
+ # We start with an empty type
+ type_def = ir_data.TypeDefinition()
+ self.assertFalse(type_def.HasField("name"))
+ self.assertIsNone(type_def.name)
- # Now setup a builder
- builder = ir_data_utils.builder(type_def)
+ # Now setup a builder
+ builder = ir_data_utils.builder(type_def)
- # Assign to a sub-child
- builder.name.name = ir_data.Word(text="phil")
+ # Assign to a sub-child
+ builder.name.name = ir_data.Word(text="phil")
- # Verify the wrapped struct is updated
- self.assertIsNotNone(type_def.name)
- self.assertIsNotNone(type_def.name.name)
- self.assertIsNotNone(type_def.name.name.text)
- self.assertEqual(type_def.name.name.text, "phil")
+ # Verify the wrapped struct is updated
+ self.assertIsNotNone(type_def.name)
+ self.assertIsNotNone(type_def.name.name)
+ self.assertIsNotNone(type_def.name.name.text)
+ self.assertEqual(type_def.name.name.text, "phil")
- def test_ir_data_builder_bad_field(self):
- """Tests accessing an undefined field name fails"""
- type_def = ir_data.TypeDefinition()
- builder = ir_data_utils.builder(type_def)
- self.assertRaises(AttributeError, lambda: builder.foo)
- # Make sure it's not set on our IR data class either
- self.assertRaises(AttributeError, getattr, type_def, "foo")
+ def test_ir_data_builder_bad_field(self):
+ """Tests accessing an undefined field name fails"""
+ type_def = ir_data.TypeDefinition()
+ builder = ir_data_utils.builder(type_def)
+ self.assertRaises(AttributeError, lambda: builder.foo)
+ # Make sure it's not set on our IR data class either
+ self.assertRaises(AttributeError, getattr, type_def, "foo")
- def test_ir_data_builder_sequence(self):
- """Tests that sequences are properly wrapped"""
- # We start with an empty type
- type_def = ir_data.TypeDefinition()
- self.assertTrue(type_def.HasField("attribute"))
- self.assertEqual(len(type_def.attribute), 0)
+ def test_ir_data_builder_sequence(self):
+ """Tests that sequences are properly wrapped"""
+ # We start with an empty type
+ type_def = ir_data.TypeDefinition()
+ self.assertTrue(type_def.HasField("attribute"))
+ self.assertEqual(len(type_def.attribute), 0)
- # Now setup a builder
- builder = ir_data_utils.builder(type_def)
+ # Now setup a builder
+ builder = ir_data_utils.builder(type_def)
- # Assign to a sequence
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
-
- builder.attribute.append(attribute)
- self.assertEqual(builder.attribute, [attribute])
- self.assertTrue(type_def.HasField("attribute"))
- self.assertEqual(len(type_def.attribute), 1)
- self.assertEqual(type_def.attribute[0], attribute)
-
- # Lets make it longer and then try iterating
- builder.attribute.append(attribute)
- self.assertEqual(len(type_def.attribute), 2)
- for attr in builder.attribute:
- # Modify the attributes
- attr.name.text = "bob"
-
- # Make sure we can build up auto-default entries from a sequence item
- builder.attribute.append(ir_data.Attribute())
- builder.attribute[-1].value.expression = ir_data.Expression()
- builder.attribute[-1].name.text = "bob"
-
- # Create an attribute to compare against
- new_attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="bob"),
- )
-
- self.assertEqual(len(type_def.attribute), 3)
- for attr in type_def.attribute:
- self.assertEqual(attr, new_attribute)
-
- # Make sure the list type is a CopyValuesList
- self.assertIsInstance(
- type_def.attribute,
- ir_data_fields.CopyValuesList,
- f"Instance is: {type(type_def.attribute)}",
- )
-
- def test_copy_from(self) -> None:
- """Tests that `CopyFrom` works."""
- location = ir_data.Location(
- start=ir_data.Position(line=1, column=1),
- end=ir_data.Position(line=1, column=2),
- )
- expression_ir = ir_data.Expression(source_location=location)
- template: ir_data.Expression = expression_parser.parse("x + y")
- expression = ir_data_utils.builder(expression_ir)
- expression.CopyFrom(template)
- self.assertIsNotNone(expression_ir.function)
- self.assertIsInstance(expression.function, ir_data_utils._IrDataBuilder)
- self.assertIsInstance(
- expression.function.args, ir_data_utils._IrDataSequenceBuilder
- )
- self.assertTrue(expression_ir.function.args)
-
- def test_copy_from_list(self):
- specs = ir_data_utils.field_specs(ir_data.Function)
- args_spec = specs["args"]
- self.assertTrue(args_spec.is_dataclass)
- template: ir_data.Expression = expression_parser.parse("x + y")
- self.assertIsNotNone(template)
- self.assertIsInstance(template, ir_data.Expression)
- self.assertIsInstance(template.function, ir_data.Function)
- self.assertIsInstance(template.function.args, ir_data_fields.CopyValuesList)
-
- location = ir_data.Location(
- start=ir_data.Position(line=1, column=1),
- end=ir_data.Position(line=1, column=2),
- )
- expression_ir = ir_data.Expression(source_location=location)
- self.assertIsInstance(expression_ir, ir_data.Expression)
- self.assertIsNone(expression_ir.function)
-
- expression_builder = ir_data_utils.builder(expression_ir)
- self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder)
- expression_builder.CopyFrom(template)
- self.assertIsNotNone(expression_ir.function)
- self.assertIsInstance(expression_ir.function, ir_data.Function)
- self.assertIsNotNone(expression_ir.function.args)
- self.assertIsInstance(
- expression_ir.function.args, ir_data_fields.CopyValuesList
- )
-
- self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder)
- self.assertIsInstance(
- expression_builder.function, ir_data_utils._IrDataBuilder
- )
- self.assertIsInstance(
- expression_builder.function.args, ir_data_utils._IrDataSequenceBuilder
- )
-
- def test_ir_data_builder_sequence_scalar(self):
- """Tests that sequences of scalars function properly"""
- # We start with an empty type
- structure = ir_data.Structure()
-
- # Now setup a builder
- builder = ir_data_utils.builder(structure)
-
- # Assign to a scalar sequence
- builder.fields_in_dependency_order.append(12)
- builder.fields_in_dependency_order.append(11)
-
- self.assertTrue(structure.HasField("fields_in_dependency_order"))
- self.assertEqual(len(structure.fields_in_dependency_order), 2)
- self.assertEqual(structure.fields_in_dependency_order[0], 12)
- self.assertEqual(structure.fields_in_dependency_order[1], 11)
- self.assertEqual(builder.fields_in_dependency_order, [12, 11])
-
- new_structure = ir_data.Structure(fields_in_dependency_order=[12, 11])
- self.assertEqual(structure, new_structure)
-
- def test_ir_data_builder_oneof(self):
- value = ir_data.AttributeValue(
- expression=ir_data.Expression(
- boolean_constant=ir_data.BooleanConstant()
+ # Assign to a sequence
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
)
- )
- builder = ir_data_utils.builder(value)
- self.assertTrue(builder.HasField("expression"))
- self.assertFalse(builder.expression.boolean_constant.value)
- builder.expression.boolean_constant.value = True
- self.assertTrue(builder.expression.boolean_constant.value)
- self.assertTrue(value.expression.boolean_constant.value)
- bool_constant = value.expression.boolean_constant
- self.assertIsInstance(bool_constant, ir_data.BooleanConstant)
+ builder.attribute.append(attribute)
+ self.assertEqual(builder.attribute, [attribute])
+ self.assertTrue(type_def.HasField("attribute"))
+ self.assertEqual(len(type_def.attribute), 1)
+ self.assertEqual(type_def.attribute[0], attribute)
+
+ # Lets make it longer and then try iterating
+ builder.attribute.append(attribute)
+ self.assertEqual(len(type_def.attribute), 2)
+ for attr in builder.attribute:
+ # Modify the attributes
+ attr.name.text = "bob"
+
+ # Make sure we can build up auto-default entries from a sequence item
+ builder.attribute.append(ir_data.Attribute())
+ builder.attribute[-1].value.expression = ir_data.Expression()
+ builder.attribute[-1].name.text = "bob"
+
+ # Create an attribute to compare against
+ new_attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="bob"),
+ )
+
+ self.assertEqual(len(type_def.attribute), 3)
+ for attr in type_def.attribute:
+ self.assertEqual(attr, new_attribute)
+
+ # Make sure the list type is a CopyValuesList
+ self.assertIsInstance(
+ type_def.attribute,
+ ir_data_fields.CopyValuesList,
+ f"Instance is: {type(type_def.attribute)}",
+ )
+
+ def test_copy_from(self) -> None:
+ """Tests that `CopyFrom` works."""
+ location = ir_data.Location(
+ start=ir_data.Position(line=1, column=1),
+ end=ir_data.Position(line=1, column=2),
+ )
+ expression_ir = ir_data.Expression(source_location=location)
+ template: ir_data.Expression = expression_parser.parse("x + y")
+ expression = ir_data_utils.builder(expression_ir)
+ expression.CopyFrom(template)
+ self.assertIsNotNone(expression_ir.function)
+ self.assertIsInstance(expression.function, ir_data_utils._IrDataBuilder)
+ self.assertIsInstance(
+ expression.function.args, ir_data_utils._IrDataSequenceBuilder
+ )
+ self.assertTrue(expression_ir.function.args)
+
+ def test_copy_from_list(self):
+ specs = ir_data_utils.field_specs(ir_data.Function)
+ args_spec = specs["args"]
+ self.assertTrue(args_spec.is_dataclass)
+ template: ir_data.Expression = expression_parser.parse("x + y")
+ self.assertIsNotNone(template)
+ self.assertIsInstance(template, ir_data.Expression)
+ self.assertIsInstance(template.function, ir_data.Function)
+ self.assertIsInstance(template.function.args, ir_data_fields.CopyValuesList)
+
+ location = ir_data.Location(
+ start=ir_data.Position(line=1, column=1),
+ end=ir_data.Position(line=1, column=2),
+ )
+ expression_ir = ir_data.Expression(source_location=location)
+ self.assertIsInstance(expression_ir, ir_data.Expression)
+ self.assertIsNone(expression_ir.function)
+
+ expression_builder = ir_data_utils.builder(expression_ir)
+ self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder)
+ expression_builder.CopyFrom(template)
+ self.assertIsNotNone(expression_ir.function)
+ self.assertIsInstance(expression_ir.function, ir_data.Function)
+ self.assertIsNotNone(expression_ir.function.args)
+ self.assertIsInstance(
+ expression_ir.function.args, ir_data_fields.CopyValuesList
+ )
+
+ self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder)
+ self.assertIsInstance(expression_builder.function, ir_data_utils._IrDataBuilder)
+ self.assertIsInstance(
+ expression_builder.function.args, ir_data_utils._IrDataSequenceBuilder
+ )
+
+ def test_ir_data_builder_sequence_scalar(self):
+ """Tests that sequences of scalars function properly"""
+ # We start with an empty type
+ structure = ir_data.Structure()
+
+ # Now setup a builder
+ builder = ir_data_utils.builder(structure)
+
+ # Assign to a scalar sequence
+ builder.fields_in_dependency_order.append(12)
+ builder.fields_in_dependency_order.append(11)
+
+ self.assertTrue(structure.HasField("fields_in_dependency_order"))
+ self.assertEqual(len(structure.fields_in_dependency_order), 2)
+ self.assertEqual(structure.fields_in_dependency_order[0], 12)
+ self.assertEqual(structure.fields_in_dependency_order[1], 11)
+ self.assertEqual(builder.fields_in_dependency_order, [12, 11])
+
+ new_structure = ir_data.Structure(fields_in_dependency_order=[12, 11])
+ self.assertEqual(structure, new_structure)
+
+ def test_ir_data_builder_oneof(self):
+ value = ir_data.AttributeValue(
+ expression=ir_data.Expression(boolean_constant=ir_data.BooleanConstant())
+ )
+ builder = ir_data_utils.builder(value)
+ self.assertTrue(builder.HasField("expression"))
+ self.assertFalse(builder.expression.boolean_constant.value)
+ builder.expression.boolean_constant.value = True
+ self.assertTrue(builder.expression.boolean_constant.value)
+ self.assertTrue(value.expression.boolean_constant.value)
+
+ bool_constant = value.expression.boolean_constant
+ self.assertIsInstance(bool_constant, ir_data.BooleanConstant)
class IrDataSerializerTest(unittest.TestCase):
- """Tests for IrDataSerializer"""
+ """Tests for IrDataSerializer"""
- def test_ir_data_serializer_to_dict(self):
- """Tests serialization with `IrDataSerializer.to_dict` with default settings"""
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
+ def test_ir_data_serializer_to_dict(self):
+ """Tests serialization with `IrDataSerializer.to_dict` with default settings"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
- serializer = ir_data_utils.IrDataSerializer(attribute)
- raw_dict = serializer.to_dict()
- expected = {
- "name": {"text": "phil", "source_location": None},
- "value": {
- "expression": {
- "constant": None,
- "constant_reference": None,
- "function": None,
- "field_reference": None,
- "boolean_constant": None,
- "builtin_reference": None,
- "type": None,
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict()
+ expected = {
+ "name": {"text": "phil", "source_location": None},
+ "value": {
+ "expression": {
+ "constant": None,
+ "constant_reference": None,
+ "function": None,
+ "field_reference": None,
+ "boolean_constant": None,
+ "builtin_reference": None,
+ "type": None,
+ "source_location": None,
+ },
+ "string_constant": None,
"source_location": None,
},
- "string_constant": None,
+ "back_end": None,
+ "is_default": None,
"source_location": None,
- },
- "back_end": None,
- "is_default": None,
- "source_location": None,
- }
- self.assertDictEqual(raw_dict, expected)
+ }
+ self.assertDictEqual(raw_dict, expected)
- def test_ir_data_serializer_to_dict_exclude_none(self):
- """Tests serialization with `IrDataSerializer.to_dict` when excluding None values"""
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
- serializer = ir_data_utils.IrDataSerializer(attribute)
- raw_dict = serializer.to_dict(exclude_none=True)
- expected = {"name": {"text": "phil"}, "value": {"expression": {}}}
- self.assertDictEqual(raw_dict, expected)
+ def test_ir_data_serializer_to_dict_exclude_none(self):
+ """Tests serialization with `IrDataSerializer.to_dict` when excluding None values"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=True)
+ expected = {"name": {"text": "phil"}, "value": {"expression": {}}}
+ self.assertDictEqual(raw_dict, expected)
- def test_ir_data_serializer_to_dict_enum(self):
- """Tests that serialization of `enum.Enum` values works properly"""
- type_def = ir_data.TypeDefinition(
- addressable_unit=ir_data.AddressableUnit.BYTE
- )
- serializer = ir_data_utils.IrDataSerializer(type_def)
- raw_dict = serializer.to_dict(exclude_none=True)
- expected = {"addressable_unit": ir_data.AddressableUnit.BYTE}
- self.assertDictEqual(raw_dict, expected)
+ def test_ir_data_serializer_to_dict_enum(self):
+ """Tests that serialization of `enum.Enum` values works properly"""
+ type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE)
+ serializer = ir_data_utils.IrDataSerializer(type_def)
+ raw_dict = serializer.to_dict(exclude_none=True)
+ expected = {"addressable_unit": ir_data.AddressableUnit.BYTE}
+ self.assertDictEqual(raw_dict, expected)
- def test_ir_data_serializer_from_dict(self):
- """Tests deserializing IR data from a serialized dict"""
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
- serializer = ir_data_utils.IrDataSerializer(attribute)
- raw_dict = serializer.to_dict(exclude_none=False)
- new_attribute = serializer.from_dict(ir_data.Attribute, raw_dict)
- self.assertEqual(attribute, new_attribute)
+ def test_ir_data_serializer_from_dict(self):
+ """Tests deserializing IR data from a serialized dict"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=False)
+ new_attribute = serializer.from_dict(ir_data.Attribute, raw_dict)
+ self.assertEqual(attribute, new_attribute)
- def test_ir_data_serializer_from_dict_enum(self):
- """Tests that deserializing `enum.Enum` values works properly"""
- type_def = ir_data.TypeDefinition(
- addressable_unit=ir_data.AddressableUnit.BYTE
- )
+ def test_ir_data_serializer_from_dict_enum(self):
+ """Tests that deserializing `enum.Enum` values works properly"""
+ type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE)
- serializer = ir_data_utils.IrDataSerializer(type_def)
- raw_dict = serializer.to_dict(exclude_none=False)
- new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict)
- self.assertEqual(type_def, new_type_def)
+ serializer = ir_data_utils.IrDataSerializer(type_def)
+ raw_dict = serializer.to_dict(exclude_none=False)
+ new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict)
+ self.assertEqual(type_def, new_type_def)
- def test_ir_data_serializer_from_dict_enum_is_str(self):
- """Tests that deserializing `enum.Enum` values works properly when string constant is used"""
- type_def = ir_data.TypeDefinition(
- addressable_unit=ir_data.AddressableUnit.BYTE
- )
- raw_dict = {"addressable_unit": "BYTE"}
- serializer = ir_data_utils.IrDataSerializer(type_def)
- new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict)
- self.assertEqual(type_def, new_type_def)
+ def test_ir_data_serializer_from_dict_enum_is_str(self):
+ """Tests that deserializing `enum.Enum` values works properly when string constant is used"""
+ type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE)
+ raw_dict = {"addressable_unit": "BYTE"}
+ serializer = ir_data_utils.IrDataSerializer(type_def)
+ new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict)
+ self.assertEqual(type_def, new_type_def)
- def test_ir_data_serializer_from_dict_exclude_none(self):
- """Tests that deserializing from a dict that excluded None values works properly"""
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
+ def test_ir_data_serializer_from_dict_exclude_none(self):
+ """Tests that deserializing from a dict that excluded None values works properly"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
- serializer = ir_data_utils.IrDataSerializer(attribute)
- raw_dict = serializer.to_dict(exclude_none=True)
- new_attribute = ir_data_utils.IrDataSerializer.from_dict(
- ir_data.Attribute, raw_dict
- )
- self.assertEqual(attribute, new_attribute)
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=True)
+ new_attribute = ir_data_utils.IrDataSerializer.from_dict(
+ ir_data.Attribute, raw_dict
+ )
+ self.assertEqual(attribute, new_attribute)
- def test_from_dict_list(self):
- function_args = [
- {
- "constant": {
- "value": "0",
+ def test_from_dict_list(self):
+ function_args = [
+ {
+ "constant": {
+ "value": "0",
+ "source_location": {
+ "start": {"line": 421, "column": 3},
+ "end": {"line": 421, "column": 4},
+ "is_synthetic": False,
+ },
+ },
+ "type": {
+ "integer": {
+ "modulus": "infinity",
+ "modular_value": "0",
+ "minimum_value": "0",
+ "maximum_value": "0",
+ }
+ },
"source_location": {
"start": {"line": 421, "column": 3},
"end": {"line": 421, "column": 4},
"is_synthetic": False,
},
},
- "type": {
- "integer": {
- "modulus": "infinity",
- "modular_value": "0",
- "minimum_value": "0",
- "maximum_value": "0",
- }
- },
- "source_location": {
- "start": {"line": 421, "column": 3},
- "end": {"line": 421, "column": 4},
- "is_synthetic": False,
- },
- },
- {
- "constant": {
- "value": "1",
+ {
+ "constant": {
+ "value": "1",
+ "source_location": {
+ "start": {"line": 421, "column": 11},
+ "end": {"line": 421, "column": 12},
+ "is_synthetic": False,
+ },
+ },
+ "type": {
+ "integer": {
+ "modulus": "infinity",
+ "modular_value": "1",
+ "minimum_value": "1",
+ "maximum_value": "1",
+ }
+ },
"source_location": {
"start": {"line": 421, "column": 11},
"end": {"line": 421, "column": 12},
"is_synthetic": False,
},
},
- "type": {
- "integer": {
- "modulus": "infinity",
- "modular_value": "1",
- "minimum_value": "1",
- "maximum_value": "1",
- }
- },
- "source_location": {
- "start": {"line": 421, "column": 11},
- "end": {"line": 421, "column": 12},
- "is_synthetic": False,
- },
- },
- ]
- function_data = {"args": function_args}
- func = ir_data_utils.IrDataSerializer.from_dict(
- ir_data.Function, function_data
- )
- self.assertIsNotNone(func)
+ ]
+ function_data = {"args": function_args}
+ func = ir_data_utils.IrDataSerializer.from_dict(ir_data.Function, function_data)
+ self.assertIsNotNone(func)
- def test_ir_data_serializer_copy_from_dict(self):
- """Tests that updating an IR data struct from a dict works properly"""
- attribute = ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil"),
- )
- serializer = ir_data_utils.IrDataSerializer(attribute)
- raw_dict = serializer.to_dict(exclude_none=False)
+ def test_ir_data_serializer_copy_from_dict(self):
+ """Tests that updating an IR data struct from a dict works properly"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=False)
- new_attribute = ir_data.Attribute()
- new_serializer = ir_data_utils.IrDataSerializer(new_attribute)
- new_serializer.copy_from_dict(raw_dict)
- self.assertEqual(attribute, new_attribute)
+ new_attribute = ir_data.Attribute()
+ new_serializer = ir_data_utils.IrDataSerializer(new_attribute)
+ new_serializer.copy_from_dict(raw_dict)
+ self.assertEqual(attribute, new_attribute)
class ReadOnlyFieldCheckerTest(unittest.TestCase):
- """Tests the ReadOnlyFieldChecker"""
+ """Tests the ReadOnlyFieldChecker"""
- def test_basic_wrapper(self):
- """Tests basic field checker actions"""
- union = ClassWithTwoUnions(
- opaque=Opaque(), boolean=True, non_union_field=10
- )
- field_checker = ir_data_utils.reader(union)
+ def test_basic_wrapper(self):
+ """Tests basic field checker actions"""
+ union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10)
+ field_checker = ir_data_utils.reader(union)
- # All accesses should return a wrapper object
- self.assertIsNotNone(field_checker.opaque)
- self.assertIsNotNone(field_checker.integer)
- self.assertIsNotNone(field_checker.boolean)
- self.assertIsNotNone(field_checker.enumeration)
- self.assertIsNotNone(field_checker.non_union_field)
- # Scalar field should pass through
- self.assertEqual(field_checker.non_union_field, 10)
+ # All accesses should return a wrapper object
+ self.assertIsNotNone(field_checker.opaque)
+ self.assertIsNotNone(field_checker.integer)
+ self.assertIsNotNone(field_checker.boolean)
+ self.assertIsNotNone(field_checker.enumeration)
+ self.assertIsNotNone(field_checker.non_union_field)
+ # Scalar field should pass through
+ self.assertEqual(field_checker.non_union_field, 10)
- # Make sure HasField works
- self.assertTrue(field_checker.HasField("opaque"))
- self.assertFalse(field_checker.HasField("integer"))
- self.assertTrue(field_checker.HasField("boolean"))
- self.assertFalse(field_checker.HasField("enumeration"))
- self.assertTrue(field_checker.HasField("non_union_field"))
+ # Make sure HasField works
+ self.assertTrue(field_checker.HasField("opaque"))
+ self.assertFalse(field_checker.HasField("integer"))
+ self.assertTrue(field_checker.HasField("boolean"))
+ self.assertFalse(field_checker.HasField("enumeration"))
+ self.assertTrue(field_checker.HasField("non_union_field"))
- def test_construct_from_field_checker(self):
- """Tests that constructing from another field checker works"""
- union = ClassWithTwoUnions(
- opaque=Opaque(), boolean=True, non_union_field=10
- )
- field_checker_orig = ir_data_utils.reader(union)
- field_checker = ir_data_utils.reader(field_checker_orig)
- self.assertIsNotNone(field_checker)
- self.assertEqual(field_checker.ir_or_spec, union)
+ def test_construct_from_field_checker(self):
+ """Tests that constructing from another field checker works"""
+ union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10)
+ field_checker_orig = ir_data_utils.reader(union)
+ field_checker = ir_data_utils.reader(field_checker_orig)
+ self.assertIsNotNone(field_checker)
+ self.assertEqual(field_checker.ir_or_spec, union)
- # All accesses should return a wrapper object
- self.assertIsNotNone(field_checker.opaque)
- self.assertIsNotNone(field_checker.integer)
- self.assertIsNotNone(field_checker.boolean)
- self.assertIsNotNone(field_checker.enumeration)
- self.assertIsNotNone(field_checker.non_union_field)
- # Scalar field should pass through
- self.assertEqual(field_checker.non_union_field, 10)
+ # All accesses should return a wrapper object
+ self.assertIsNotNone(field_checker.opaque)
+ self.assertIsNotNone(field_checker.integer)
+ self.assertIsNotNone(field_checker.boolean)
+ self.assertIsNotNone(field_checker.enumeration)
+ self.assertIsNotNone(field_checker.non_union_field)
+ # Scalar field should pass through
+ self.assertEqual(field_checker.non_union_field, 10)
- # Make sure HasField works
- self.assertTrue(field_checker.HasField("opaque"))
- self.assertFalse(field_checker.HasField("integer"))
- self.assertTrue(field_checker.HasField("boolean"))
- self.assertFalse(field_checker.HasField("enumeration"))
- self.assertTrue(field_checker.HasField("non_union_field"))
+ # Make sure HasField works
+ self.assertTrue(field_checker.HasField("opaque"))
+ self.assertFalse(field_checker.HasField("integer"))
+ self.assertTrue(field_checker.HasField("boolean"))
+ self.assertFalse(field_checker.HasField("enumeration"))
+ self.assertTrue(field_checker.HasField("non_union_field"))
- def test_read_only(self) -> None:
- """Tests that the read only wrapper really is read only"""
- union = ClassWithTwoUnions(
- opaque=Opaque(), boolean=True, non_union_field=10
- )
- field_checker = ir_data_utils.reader(union)
+ def test_read_only(self) -> None:
+ """Tests that the read only wrapper really is read only"""
+ union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10)
+ field_checker = ir_data_utils.reader(union)
- def set_field():
- field_checker.opaque = None
+ def set_field():
+ field_checker.opaque = None
- self.assertRaises(AttributeError, set_field)
+ self.assertRaises(AttributeError, set_field)
ir_data_fields.cache_message_specs(
- sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message)
+ sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message
+)
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/ir_util.py b/compiler/util/ir_util.py
index 603c0c0..2d8763b 100644
--- a/compiler/util/ir_util.py
+++ b/compiler/util/ir_util.py
@@ -24,376 +24,384 @@
def get_attribute(attribute_list, name):
- """Finds name in attribute_list and returns a AttributeValue or None."""
- if not attribute_list:
- return None
- attribute_value = None
- for attr in attribute_list:
- if attr.name.text == name and not attr.is_default:
- assert attribute_value is None, 'Duplicate attribute "{}".'.format(name)
- attribute_value = attr.value
- return attribute_value
+ """Finds name in attribute_list and returns a AttributeValue or None."""
+ if not attribute_list:
+ return None
+ attribute_value = None
+ for attr in attribute_list:
+ if attr.name.text == name and not attr.is_default:
+ assert attribute_value is None, 'Duplicate attribute "{}".'.format(name)
+ attribute_value = attr.value
+ return attribute_value
def get_boolean_attribute(attribute_list, name, default_value=None):
- """Returns the boolean value of an attribute, if any, or default_value.
+ """Returns the boolean value of an attribute, if any, or default_value.
- Arguments:
- attribute_list: A list of attributes to search.
- name: The name of the desired attribute.
- default_value: A value to return if name is not found in attribute_list,
- or the attribute does not have a boolean value.
+ Arguments:
+ attribute_list: A list of attributes to search.
+ name: The name of the desired attribute.
+ default_value: A value to return if name is not found in attribute_list,
+ or the attribute does not have a boolean value.
- Returns:
- The boolean value of the requested attribute, or default_value if the
- requested attribute is not found or has a non-boolean value.
- """
- attribute_value = get_attribute(attribute_list, name)
- if (not attribute_value or
- not attribute_value.expression.HasField("boolean_constant")):
- return default_value
- return attribute_value.expression.boolean_constant.value
+ Returns:
+ The boolean value of the requested attribute, or default_value if the
+ requested attribute is not found or has a non-boolean value.
+ """
+ attribute_value = get_attribute(attribute_list, name)
+ if not attribute_value or not attribute_value.expression.HasField(
+ "boolean_constant"
+ ):
+ return default_value
+ return attribute_value.expression.boolean_constant.value
def get_integer_attribute(attribute_list, name, default_value=None):
- """Returns the integer value of an attribute, if any, or default_value.
+ """Returns the integer value of an attribute, if any, or default_value.
- Arguments:
- attribute_list: A list of attributes to search.
- name: The name of the desired attribute.
- default_value: A value to return if name is not found in attribute_list,
- or the attribute does not have an integer value.
+ Arguments:
+ attribute_list: A list of attributes to search.
+ name: The name of the desired attribute.
+ default_value: A value to return if name is not found in attribute_list,
+ or the attribute does not have an integer value.
- Returns:
- The integer value of the requested attribute, or default_value if the
- requested attribute is not found or has a non-integer value.
- """
- attribute_value = get_attribute(attribute_list, name)
- if (not attribute_value or
- attribute_value.expression.type.WhichOneof("type") != "integer" or
- not is_constant(attribute_value.expression)):
- return default_value
- return constant_value(attribute_value.expression)
+ Returns:
+ The integer value of the requested attribute, or default_value if the
+ requested attribute is not found or has a non-integer value.
+ """
+ attribute_value = get_attribute(attribute_list, name)
+ if (
+ not attribute_value
+ or attribute_value.expression.type.WhichOneof("type") != "integer"
+ or not is_constant(attribute_value.expression)
+ ):
+ return default_value
+ return constant_value(attribute_value.expression)
def is_constant(expression, bindings=None):
- return constant_value(expression, bindings) is not None
+ return constant_value(expression, bindings) is not None
def is_constant_type(expression_type):
- """Returns True if expression_type is inhabited by a single value."""
- expression_type = ir_data_utils.reader(expression_type)
- return (expression_type.integer.modulus == "infinity" or
- expression_type.boolean.HasField("value") or
- expression_type.enumeration.HasField("value"))
+ """Returns True if expression_type is inhabited by a single value."""
+ expression_type = ir_data_utils.reader(expression_type)
+ return (
+ expression_type.integer.modulus == "infinity"
+ or expression_type.boolean.HasField("value")
+ or expression_type.enumeration.HasField("value")
+ )
def constant_value(expression, bindings=None):
- """Evaluates expression with the given bindings."""
- if expression is None:
- return None
- expression = ir_data_utils.reader(expression)
- if expression.WhichOneof("expression") == "constant":
- return int(expression.constant.value or 0)
- elif expression.WhichOneof("expression") == "constant_reference":
- # We can't look up the constant reference without the IR, but by the time
- # constant_value is called, the actual values should have been propagated to
- # the type information.
- if expression.type.WhichOneof("type") == "integer":
- assert expression.type.integer.modulus == "infinity"
- return int(expression.type.integer.modular_value)
- elif expression.type.WhichOneof("type") == "boolean":
- assert expression.type.boolean.HasField("value")
- return expression.type.boolean.value
- elif expression.type.WhichOneof("type") == "enumeration":
- assert expression.type.enumeration.HasField("value")
- return int(expression.type.enumeration.value)
+ """Evaluates expression with the given bindings."""
+ if expression is None:
+ return None
+ expression = ir_data_utils.reader(expression)
+ if expression.WhichOneof("expression") == "constant":
+ return int(expression.constant.value or 0)
+ elif expression.WhichOneof("expression") == "constant_reference":
+ # We can't look up the constant reference without the IR, but by the time
+ # constant_value is called, the actual values should have been propagated to
+ # the type information.
+ if expression.type.WhichOneof("type") == "integer":
+ assert expression.type.integer.modulus == "infinity"
+ return int(expression.type.integer.modular_value)
+ elif expression.type.WhichOneof("type") == "boolean":
+ assert expression.type.boolean.HasField("value")
+ return expression.type.boolean.value
+ elif expression.type.WhichOneof("type") == "enumeration":
+ assert expression.type.enumeration.HasField("value")
+ return int(expression.type.enumeration.value)
+ else:
+ assert False, "Unexpected expression type {}".format(
+ expression.type.WhichOneof("type")
+ )
+ elif expression.WhichOneof("expression") == "function":
+ return _constant_value_of_function(expression.function, bindings)
+ elif expression.WhichOneof("expression") == "field_reference":
+ return None
+ elif expression.WhichOneof("expression") == "boolean_constant":
+ return expression.boolean_constant.value
+ elif expression.WhichOneof("expression") == "builtin_reference":
+ name = expression.builtin_reference.canonical_name.object_path[0]
+ if bindings and name in bindings:
+ return bindings[name]
+ else:
+ return None
+ elif expression.WhichOneof("expression") is None:
+ return None
else:
- assert False, "Unexpected expression type {}".format(
- expression.type.WhichOneof("type"))
- elif expression.WhichOneof("expression") == "function":
- return _constant_value_of_function(expression.function, bindings)
- elif expression.WhichOneof("expression") == "field_reference":
- return None
- elif expression.WhichOneof("expression") == "boolean_constant":
- return expression.boolean_constant.value
- elif expression.WhichOneof("expression") == "builtin_reference":
- name = expression.builtin_reference.canonical_name.object_path[0]
- if bindings and name in bindings:
- return bindings[name]
- else:
- return None
- elif expression.WhichOneof("expression") is None:
- return None
- else:
- assert False, "Unexpected expression kind {}".format(
- expression.WhichOneof("expression"))
+ assert False, "Unexpected expression kind {}".format(
+ expression.WhichOneof("expression")
+ )
def _constant_value_of_function(function, bindings):
- """Returns the constant value of evaluating `function`, or None."""
- values = [constant_value(arg, bindings) for arg in function.args]
- # Expressions like `$is_statically_sized && 1 <= $static_size_in_bits <= 64`
- # should return False, not None, if `$is_statically_sized` is false, even
- # though `$static_size_in_bits` is unknown.
- #
- # The easiest way to allow this is to use a three-way logic chart for each;
- # specifically:
- #
- # AND: True False Unknown
- # +--------------------------
- # True | True False Unknown
- # False | False False False
- # Unknown | Unknown False Unknown
- #
- # OR: True False Unknown
- # +--------------------------
- # True | True True True
- # False | True False Unknown
- # Unknown | True Unknown Unknown
- #
- # This raises the question of just how many constant-from-nonconstant
- # expressions Emboss should support. There are, after all, a vast number of
- # constant expression patterns built from non-constant subexpressions, such as
- # `0 * X` or `X == X` or `3 * X == X + X + X`. I (bolms@) am not implementing
- # any further special cases because I do not see any practical use for them.
- if function.function == ir_data.FunctionMapping.UNKNOWN:
- return None
- if function.function == ir_data.FunctionMapping.AND:
- if any(value is False for value in values):
- return False
- elif any(value is None for value in values):
- return None
- else:
- return True
- elif function.function == ir_data.FunctionMapping.OR:
- if any(value is True for value in values):
- return True
- elif any(value is None for value in values):
- return None
- else:
- return False
- elif function.function == ir_data.FunctionMapping.CHOICE:
- if values[0] is None:
- return None
- else:
- return values[1] if values[0] else values[2]
- # Other than the logical operators and choice operator, the result of any
- # function on an unknown value is, itself, considered unknown.
- if any(value is None for value in values):
- return None
- functions = {
- ir_data.FunctionMapping.ADDITION: operator.add,
- ir_data.FunctionMapping.SUBTRACTION: operator.sub,
- ir_data.FunctionMapping.MULTIPLICATION: operator.mul,
- ir_data.FunctionMapping.EQUALITY: operator.eq,
- ir_data.FunctionMapping.INEQUALITY: operator.ne,
- ir_data.FunctionMapping.LESS: operator.lt,
- ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le,
- ir_data.FunctionMapping.GREATER: operator.gt,
- ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge,
- # Python's max([1, 2]) == 2; max(1, 2) == 2; max([1]) == 1; but max(1)
- # throws a TypeError ("'int' object is not iterable").
- ir_data.FunctionMapping.MAXIMUM: lambda *x: max(x),
- }
- return functions[function.function](*values)
+ """Returns the constant value of evaluating `function`, or None."""
+ values = [constant_value(arg, bindings) for arg in function.args]
+ # Expressions like `$is_statically_sized && 1 <= $static_size_in_bits <= 64`
+ # should return False, not None, if `$is_statically_sized` is false, even
+ # though `$static_size_in_bits` is unknown.
+ #
+ # The easiest way to allow this is to use a three-way logic chart for each;
+ # specifically:
+ #
+ # AND: True False Unknown
+ # +--------------------------
+ # True | True False Unknown
+ # False | False False False
+ # Unknown | Unknown False Unknown
+ #
+ # OR: True False Unknown
+ # +--------------------------
+ # True | True True True
+ # False | True False Unknown
+ # Unknown | True Unknown Unknown
+ #
+ # This raises the question of just how many constant-from-nonconstant
+ # expressions Emboss should support. There are, after all, a vast number of
+ # constant expression patterns built from non-constant subexpressions, such as
+ # `0 * X` or `X == X` or `3 * X == X + X + X`. I (bolms@) am not implementing
+ # any further special cases because I do not see any practical use for them.
+ if function.function == ir_data.FunctionMapping.UNKNOWN:
+ return None
+ if function.function == ir_data.FunctionMapping.AND:
+ if any(value is False for value in values):
+ return False
+ elif any(value is None for value in values):
+ return None
+ else:
+ return True
+ elif function.function == ir_data.FunctionMapping.OR:
+ if any(value is True for value in values):
+ return True
+ elif any(value is None for value in values):
+ return None
+ else:
+ return False
+ elif function.function == ir_data.FunctionMapping.CHOICE:
+ if values[0] is None:
+ return None
+ else:
+ return values[1] if values[0] else values[2]
+ # Other than the logical operators and choice operator, the result of any
+ # function on an unknown value is, itself, considered unknown.
+ if any(value is None for value in values):
+ return None
+ functions = {
+ ir_data.FunctionMapping.ADDITION: operator.add,
+ ir_data.FunctionMapping.SUBTRACTION: operator.sub,
+ ir_data.FunctionMapping.MULTIPLICATION: operator.mul,
+ ir_data.FunctionMapping.EQUALITY: operator.eq,
+ ir_data.FunctionMapping.INEQUALITY: operator.ne,
+ ir_data.FunctionMapping.LESS: operator.lt,
+ ir_data.FunctionMapping.LESS_OR_EQUAL: operator.le,
+ ir_data.FunctionMapping.GREATER: operator.gt,
+ ir_data.FunctionMapping.GREATER_OR_EQUAL: operator.ge,
+ # Python's max([1, 2]) == 2; max(1, 2) == 2; max([1]) == 1; but max(1)
+ # throws a TypeError ("'int' object is not iterable").
+ ir_data.FunctionMapping.MAXIMUM: lambda *x: max(x),
+ }
+ return functions[function.function](*values)
def _hashable_form_of_name(name):
- return (name.module_file,) + tuple(name.object_path)
+ return (name.module_file,) + tuple(name.object_path)
def hashable_form_of_reference(reference):
- """Returns a representation of reference that can be used as a dict key.
+ """Returns a representation of reference that can be used as a dict key.
- Arguments:
- reference: An ir_data.Reference or ir_data.NameDefinition.
+ Arguments:
+ reference: An ir_data.Reference or ir_data.NameDefinition.
- Returns:
- A tuple of the module_file and object_path.
- """
- return _hashable_form_of_name(reference.canonical_name)
+ Returns:
+ A tuple of the module_file and object_path.
+ """
+ return _hashable_form_of_name(reference.canonical_name)
def hashable_form_of_field_reference(field_reference):
- """Returns a representation of field_reference that can be used as a dict key.
+ """Returns a representation of field_reference that can be used as a dict key.
- Arguments:
- field_reference: An ir_data.FieldReference
+ Arguments:
+ field_reference: An ir_data.FieldReference
- Returns:
- A tuple of tuples of the module_files and object_paths.
- """
- return tuple(_hashable_form_of_name(reference.canonical_name)
- for reference in field_reference.path)
+ Returns:
+ A tuple of tuples of the module_files and object_paths.
+ """
+ return tuple(
+ _hashable_form_of_name(reference.canonical_name)
+ for reference in field_reference.path
+ )
def is_array(type_ir):
- """Returns true if type_ir is an array type."""
- return type_ir.HasField("array_type")
+ """Returns true if type_ir is an array type."""
+ return type_ir.HasField("array_type")
def _find_path_in_structure_field(path, field):
- if not path:
- return field
- return None
+ if not path:
+ return field
+ return None
def _find_path_in_structure(path, type_definition):
- for field in type_definition.structure.field:
- if field.name.name.text == path[0]:
- return _find_path_in_structure_field(path[1:], field)
- return None
+ for field in type_definition.structure.field:
+ if field.name.name.text == path[0]:
+ return _find_path_in_structure_field(path[1:], field)
+ return None
def _find_path_in_enumeration(path, type_definition):
- if len(path) != 1:
+ if len(path) != 1:
+ return None
+ for value in type_definition.enumeration.value:
+ if value.name.name.text == path[0]:
+ return value
return None
- for value in type_definition.enumeration.value:
- if value.name.name.text == path[0]:
- return value
- return None
def _find_path_in_parameters(path, type_definition):
- if len(path) > 1 or not type_definition.HasField("runtime_parameter"):
+ if len(path) > 1 or not type_definition.HasField("runtime_parameter"):
+ return None
+ for parameter in type_definition.runtime_parameter:
+ if ir_data_utils.reader(parameter).name.name.text == path[0]:
+ return parameter
return None
- for parameter in type_definition.runtime_parameter:
- if ir_data_utils.reader(parameter).name.name.text == path[0]:
- return parameter
- return None
def _find_path_in_type_definition(path, type_definition):
- """Finds the object with the given path in the given type_definition."""
- if not path:
- return type_definition
- obj = _find_path_in_parameters(path, type_definition)
- if obj:
- return obj
- if type_definition.HasField("structure"):
- obj = _find_path_in_structure(path, type_definition)
- elif type_definition.HasField("enumeration"):
- obj = _find_path_in_enumeration(path, type_definition)
- if obj:
- return obj
- else:
- return _find_path_in_type_list(path, type_definition.subtype or [])
+ """Finds the object with the given path in the given type_definition."""
+ if not path:
+ return type_definition
+ obj = _find_path_in_parameters(path, type_definition)
+ if obj:
+ return obj
+ if type_definition.HasField("structure"):
+ obj = _find_path_in_structure(path, type_definition)
+ elif type_definition.HasField("enumeration"):
+ obj = _find_path_in_enumeration(path, type_definition)
+ if obj:
+ return obj
+ else:
+ return _find_path_in_type_list(path, type_definition.subtype or [])
def _find_path_in_type_list(path, type_list):
- for type_definition in type_list:
- if type_definition.name.name.text == path[0]:
- return _find_path_in_type_definition(path[1:], type_definition)
- return None
+ for type_definition in type_list:
+ if type_definition.name.name.text == path[0]:
+ return _find_path_in_type_definition(path[1:], type_definition)
+ return None
def _find_path_in_module(path, module_ir):
- if not path:
- return module_ir
- return _find_path_in_type_list(path, module_ir.type)
+ if not path:
+ return module_ir
+ return _find_path_in_type_list(path, module_ir.type)
def find_object_or_none(name, ir):
- """Finds the object with the given canonical name, if it exists.."""
- if (isinstance(name, ir_data.Reference) or
- isinstance(name, ir_data.NameDefinition)):
- path = _hashable_form_of_name(name.canonical_name)
- elif isinstance(name, ir_data.CanonicalName):
- path = _hashable_form_of_name(name)
- else:
- path = name
+ """Finds the object with the given canonical name, if it exists.."""
+ if isinstance(name, ir_data.Reference) or isinstance(name, ir_data.NameDefinition):
+ path = _hashable_form_of_name(name.canonical_name)
+ elif isinstance(name, ir_data.CanonicalName):
+ path = _hashable_form_of_name(name)
+ else:
+ path = name
- for module in ir.module:
- if module.source_file_name == path[0]:
- return _find_path_in_module(path[1:], module)
+ for module in ir.module:
+ if module.source_file_name == path[0]:
+ return _find_path_in_module(path[1:], module)
- return None
+ return None
def find_object(name, ir):
- """Finds the IR of the type, field, or value with the given canonical name."""
- result = find_object_or_none(name, ir)
- assert result is not None, "Bad reference {}".format(name)
- return result
+ """Finds the IR of the type, field, or value with the given canonical name."""
+ result = find_object_or_none(name, ir)
+ assert result is not None, "Bad reference {}".format(name)
+ return result
def find_parent_object(name, ir):
- """Finds the parent object of the object with the given canonical name."""
- if (isinstance(name, ir_data.Reference) or
- isinstance(name, ir_data.NameDefinition)):
- path = _hashable_form_of_name(name.canonical_name)
- elif isinstance(name, ir_data.CanonicalName):
- path = _hashable_form_of_name(name)
- else:
- path = name
- return find_object(path[:-1], ir)
+ """Finds the parent object of the object with the given canonical name."""
+ if isinstance(name, ir_data.Reference) or isinstance(name, ir_data.NameDefinition):
+ path = _hashable_form_of_name(name.canonical_name)
+ elif isinstance(name, ir_data.CanonicalName):
+ path = _hashable_form_of_name(name)
+ else:
+ path = name
+ return find_object(path[:-1], ir)
def get_base_type(type_ir):
- """Returns the base type of the given type.
+ """Returns the base type of the given type.
- Arguments:
- type_ir: IR of a type reference.
+ Arguments:
+ type_ir: IR of a type reference.
- Returns:
- If type_ir corresponds to an atomic type (like "UInt"), returns type_ir. If
- type_ir corresponds to an array type (like "UInt:8[12]" or "Square[8][8]"),
- returns the type after stripping off the array types ("UInt" or "Square").
- """
- while type_ir.HasField("array_type"):
- type_ir = type_ir.array_type.base_type
- assert type_ir.HasField("atomic_type"), (
- "Unknown kind of type {}".format(type_ir))
- return type_ir
+ Returns:
+ If type_ir corresponds to an atomic type (like "UInt"), returns type_ir. If
+ type_ir corresponds to an array type (like "UInt:8[12]" or "Square[8][8]"),
+ returns the type after stripping off the array types ("UInt" or "Square").
+ """
+ while type_ir.HasField("array_type"):
+ type_ir = type_ir.array_type.base_type
+ assert type_ir.HasField("atomic_type"), "Unknown kind of type {}".format(type_ir)
+ return type_ir
def fixed_size_of_type_in_bits(type_ir, ir):
- """Returns the fixed, known size for the given type, in bits, or None.
+ """Returns the fixed, known size for the given type, in bits, or None.
- Arguments:
- type_ir: The IR of a type.
- ir: A complete IR, used to resolve references to types.
+ Arguments:
+ type_ir: The IR of a type.
+ ir: A complete IR, used to resolve references to types.
- Returns:
- size if the size of the type can be determined, otherwise None.
- """
- array_multiplier = 1
- while type_ir.HasField("array_type"):
- if type_ir.array_type.WhichOneof("size") == "automatic":
- return None
+ Returns:
+ size if the size of the type can be determined, otherwise None.
+ """
+ array_multiplier = 1
+ while type_ir.HasField("array_type"):
+ if type_ir.array_type.WhichOneof("size") == "automatic":
+ return None
+ else:
+ assert (
+ type_ir.array_type.WhichOneof("size") == "element_count"
+ ), 'Expected array size to be "automatic" or "element_count".'
+ element_count = type_ir.array_type.element_count
+ if not is_constant(element_count):
+ return None
+ else:
+ array_multiplier *= constant_value(element_count)
+ assert not type_ir.HasField(
+ "size_in_bits"
+ ), "TODO(bolms): implement explicitly-sized arrays"
+ type_ir = type_ir.array_type.base_type
+ assert type_ir.HasField("atomic_type"), "Unexpected type!"
+ if type_ir.HasField("size_in_bits"):
+ size = constant_value(type_ir.size_in_bits)
else:
- assert type_ir.array_type.WhichOneof("size") == "element_count", (
- 'Expected array size to be "automatic" or "element_count".')
- element_count = type_ir.array_type.element_count
- if not is_constant(element_count):
- return None
- else:
- array_multiplier *= constant_value(element_count)
- assert not type_ir.HasField("size_in_bits"), (
- "TODO(bolms): implement explicitly-sized arrays")
- type_ir = type_ir.array_type.base_type
- assert type_ir.HasField("atomic_type"), "Unexpected type!"
- if type_ir.HasField("size_in_bits"):
- size = constant_value(type_ir.size_in_bits)
- else:
- type_definition = find_object(type_ir.atomic_type.reference, ir)
- size_attr = get_attribute(type_definition.attribute, _FIXED_SIZE_ATTRIBUTE)
- if not size_attr:
- return None
- size = constant_value(size_attr.expression)
- return size * array_multiplier
+ type_definition = find_object(type_ir.atomic_type.reference, ir)
+ size_attr = get_attribute(type_definition.attribute, _FIXED_SIZE_ATTRIBUTE)
+ if not size_attr:
+ return None
+ size = constant_value(size_attr.expression)
+ return size * array_multiplier
def field_is_virtual(field_ir):
- """Returns true if the field is virtual."""
- # TODO(bolms): Should there be a more explicit indicator that a field is
- # virtual?
- return not field_ir.HasField("location")
+ """Returns true if the field is virtual."""
+ # TODO(bolms): Should there be a more explicit indicator that a field is
+ # virtual?
+ return not field_ir.HasField("location")
def field_is_read_only(field_ir):
- """Returns true if the field is read-only."""
- # For now, all virtual fields are read-only, and no non-virtual fields are
- # read-only.
- return ir_data_utils.reader(field_ir).write_method.read_only
+ """Returns true if the field is read-only."""
+ # For now, all virtual fields are read-only, and no non-virtual fields are
+ # read-only.
+ return ir_data_utils.reader(field_ir).write_method.read_only
diff --git a/compiler/util/ir_util_test.py b/compiler/util/ir_util_test.py
index 8e0b37a..07cf4dc 100644
--- a/compiler/util/ir_util_test.py
+++ b/compiler/util/ir_util_test.py
@@ -22,397 +22,559 @@
def _parse_expression(text):
- return expression_parser.parse(text)
+ return expression_parser.parse(text)
class IrUtilTest(unittest.TestCase):
- """Tests for the miscellaneous utility functions in ir_util.py."""
+ """Tests for the miscellaneous utility functions in ir_util.py."""
- def test_is_constant_integer(self):
- self.assertTrue(ir_util.is_constant(_parse_expression("6")))
- expression = _parse_expression("12")
- # The type information should be ignored for constants like this one.
- ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
- self.assertTrue(ir_util.is_constant(expression))
+ def test_is_constant_integer(self):
+ self.assertTrue(ir_util.is_constant(_parse_expression("6")))
+ expression = _parse_expression("12")
+ # The type information should be ignored for constants like this one.
+ ir_data_utils.builder(expression).type.integer.CopyFrom(ir_data.IntegerType())
+ self.assertTrue(ir_util.is_constant(expression))
- def test_is_constant_boolean(self):
- self.assertTrue(ir_util.is_constant(_parse_expression("true")))
- expression = _parse_expression("true")
- # The type information should be ignored for constants like this one.
- ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
- self.assertTrue(ir_util.is_constant(expression))
+ def test_is_constant_boolean(self):
+ self.assertTrue(ir_util.is_constant(_parse_expression("true")))
+ expression = _parse_expression("true")
+ # The type information should be ignored for constants like this one.
+ ir_data_utils.builder(expression).type.boolean.CopyFrom(ir_data.BooleanType())
+ self.assertTrue(ir_util.is_constant(expression))
- def test_is_constant_enum(self):
- self.assertTrue(ir_util.is_constant(ir_data.Expression(
- constant_reference=ir_data.Reference(),
- type=ir_data.ExpressionType(enumeration=ir_data.EnumType(value="12")))))
+ def test_is_constant_enum(self):
+ self.assertTrue(
+ ir_util.is_constant(
+ ir_data.Expression(
+ constant_reference=ir_data.Reference(),
+ type=ir_data.ExpressionType(
+ enumeration=ir_data.EnumType(value="12")
+ ),
+ )
+ )
+ )
- def test_is_constant_integer_type(self):
- self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType(
- integer=ir_data.IntegerType(
- modulus="10",
- modular_value="5",
- minimum_value="-5",
- maximum_value="15"))))
- self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType(
- integer=ir_data.IntegerType(
- modulus="infinity",
- modular_value="5",
- minimum_value="5",
- maximum_value="5"))))
-
- def test_is_constant_boolean_type(self):
- self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType(
- boolean=ir_data.BooleanType())))
- self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType(
- boolean=ir_data.BooleanType(value=True))))
- self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType(
- boolean=ir_data.BooleanType(value=False))))
-
- def test_is_constant_enumeration_type(self):
- self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType(
- enumeration=ir_data.EnumType())))
- self.assertTrue(ir_util.is_constant_type(ir_data.ExpressionType(
- enumeration=ir_data.EnumType(value="0"))))
-
- def test_is_constant_opaque_type(self):
- self.assertFalse(ir_util.is_constant_type(ir_data.ExpressionType(
- opaque=ir_data.OpaqueType())))
-
- def test_constant_value_of_integer(self):
- self.assertEqual(6, ir_util.constant_value(_parse_expression("6")))
-
- def test_constant_value_of_none(self):
- self.assertIsNone(ir_util.constant_value(ir_data.Expression()))
-
- def test_constant_value_of_addition(self):
- self.assertEqual(6, ir_util.constant_value(_parse_expression("2+4")))
-
- def test_constant_value_of_subtraction(self):
- self.assertEqual(-2, ir_util.constant_value(_parse_expression("2-4")))
-
- def test_constant_value_of_multiplication(self):
- self.assertEqual(8, ir_util.constant_value(_parse_expression("2*4")))
-
- def test_constant_value_of_equality(self):
- self.assertFalse(ir_util.constant_value(_parse_expression("2 == 4")))
-
- def test_constant_value_of_inequality(self):
- self.assertTrue(ir_util.constant_value(_parse_expression("2 != 4")))
-
- def test_constant_value_of_less(self):
- self.assertTrue(ir_util.constant_value(_parse_expression("2 < 4")))
-
- def test_constant_value_of_less_or_equal(self):
- self.assertTrue(ir_util.constant_value(_parse_expression("2 <= 4")))
-
- def test_constant_value_of_greater(self):
- self.assertFalse(ir_util.constant_value(_parse_expression("2 > 4")))
-
- def test_constant_value_of_greater_or_equal(self):
- self.assertFalse(ir_util.constant_value(_parse_expression("2 >= 4")))
-
- def test_constant_value_of_and(self):
- self.assertFalse(ir_util.constant_value(_parse_expression("true && false")))
- self.assertTrue(ir_util.constant_value(_parse_expression("true && true")))
-
- def test_constant_value_of_or(self):
- self.assertTrue(ir_util.constant_value(_parse_expression("true || false")))
- self.assertFalse(
- ir_util.constant_value(_parse_expression("false || false")))
-
- def test_constant_value_of_choice(self):
- self.assertEqual(
- 10, ir_util.constant_value(_parse_expression("false ? 20 : 10")))
- self.assertEqual(
- 20, ir_util.constant_value(_parse_expression("true ? 20 : 10")))
-
- def test_constant_value_of_choice_with_unknown_other_branch(self):
- self.assertEqual(
- 10, ir_util.constant_value(_parse_expression("false ? foo : 10")))
- self.assertEqual(
- 20, ir_util.constant_value(_parse_expression("true ? 20 : foo")))
-
- def test_constant_value_of_maximum(self):
- self.assertEqual(10,
- ir_util.constant_value(_parse_expression("$max(5, 10)")))
- self.assertEqual(10,
- ir_util.constant_value(_parse_expression("$max(10)")))
- self.assertEqual(
- 10,
- ir_util.constant_value(_parse_expression("$max(5, 9, 7, 10, 6, -100)")))
-
- def test_constant_value_of_boolean(self):
- self.assertTrue(ir_util.constant_value(_parse_expression("true")))
- self.assertFalse(ir_util.constant_value(_parse_expression("false")))
-
- def test_constant_value_of_enum(self):
- self.assertEqual(12, ir_util.constant_value(ir_data.Expression(
- constant_reference=ir_data.Reference(),
- type=ir_data.ExpressionType(enumeration=ir_data.EnumType(value="12")))))
-
- def test_constant_value_of_integer_reference(self):
- self.assertEqual(12, ir_util.constant_value(ir_data.Expression(
- constant_reference=ir_data.Reference(),
- type=ir_data.ExpressionType(
- integer=ir_data.IntegerType(modulus="infinity",
- modular_value="12")))))
-
- def test_constant_value_of_boolean_reference(self):
- self.assertTrue(ir_util.constant_value(ir_data.Expression(
- constant_reference=ir_data.Reference(),
- type=ir_data.ExpressionType(boolean=ir_data.BooleanType(value=True)))))
-
- def test_constant_value_of_builtin_reference(self):
- self.assertEqual(12, ir_util.constant_value(
- ir_data.Expression(
- builtin_reference=ir_data.Reference(
- canonical_name=ir_data.CanonicalName(object_path=["$foo"]))),
- {"$foo": 12}))
-
- def test_constant_value_of_field_reference(self):
- self.assertIsNone(ir_util.constant_value(_parse_expression("foo")))
-
- def test_constant_value_of_missing_builtin_reference(self):
- self.assertIsNone(ir_util.constant_value(
- _parse_expression("$static_size_in_bits"), {"$bar": 12}))
-
- def test_constant_value_of_present_builtin_reference(self):
- self.assertEqual(12, ir_util.constant_value(
- _parse_expression("$static_size_in_bits"),
- {"$static_size_in_bits": 12}))
-
- def test_constant_false_value_of_operator_and_with_missing_value(self):
- self.assertIs(False, ir_util.constant_value(
- _parse_expression("false && foo"), {"bar": 12}))
- self.assertIs(False, ir_util.constant_value(
- _parse_expression("foo && false"), {"bar": 12}))
-
- def test_constant_false_value_of_operator_and_known_value(self):
- self.assertTrue(ir_util.constant_value(
- _parse_expression("true && $is_statically_sized"),
- {"$is_statically_sized": True}))
-
- def test_constant_none_value_of_operator_and_with_missing_value(self):
- self.assertIsNone(ir_util.constant_value(
- _parse_expression("true && foo"), {"bar": 12}))
- self.assertIsNone(ir_util.constant_value(
- _parse_expression("foo && true"), {"bar": 12}))
-
- def test_constant_false_value_of_operator_or_with_missing_value(self):
- self.assertTrue(ir_util.constant_value(
- _parse_expression("true || foo"), {"bar": 12}))
- self.assertTrue(ir_util.constant_value(
- _parse_expression("foo || true"), {"bar": 12}))
-
- def test_constant_none_value_of_operator_or_with_missing_value(self):
- self.assertIsNone(ir_util.constant_value(
- _parse_expression("foo || false"), {"bar": 12}))
- self.assertIsNone(ir_util.constant_value(
- _parse_expression("false || foo"), {"bar": 12}))
-
- def test_constant_value_of_operator_plus_with_missing_value(self):
- self.assertIsNone(ir_util.constant_value(
- _parse_expression("12 + foo"), {"bar": 12}))
-
- def test_is_array(self):
- self.assertTrue(
- ir_util.is_array(ir_data.Type(array_type=ir_data.ArrayType())))
- self.assertFalse(
- ir_util.is_array(ir_data.Type(atomic_type=ir_data.AtomicType())))
-
- def test_get_attribute(self):
- type_def = ir_data.TypeDefinition(attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("true")),
- name=ir_data.Word(text="bob")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob2")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("true")),
- name=ir_data.Word(text="bob2"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob3"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word()),
- ])
- self.assertEqual(
- ir_data.AttributeValue(expression=_parse_expression("true")),
- ir_util.get_attribute(type_def.attribute, "bob"))
- self.assertEqual(
- ir_data.AttributeValue(expression=_parse_expression("false")),
- ir_util.get_attribute(type_def.attribute, "bob2"))
- self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "Bob"))
- self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "bob3"))
-
- def test_get_boolean_attribute(self):
- type_def = ir_data.TypeDefinition(attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("true")),
- name=ir_data.Word(text="bob")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob2")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("true")),
- name=ir_data.Word(text="bob2"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob3"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word()),
- ])
- self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute, "bob"))
- self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute,
- "bob",
- default_value=False))
- self.assertFalse(ir_util.get_boolean_attribute(type_def.attribute, "bob2"))
- self.assertFalse(ir_util.get_boolean_attribute(type_def.attribute,
- "bob2",
- default_value=True))
- self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "Bob"))
- self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute,
- "Bob",
- default_value=True))
- self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "bob3"))
-
- def test_get_integer_attribute(self):
- type_def = ir_data.TypeDefinition(attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- type=ir_data.ExpressionType(integer=ir_data.IntegerType()))),
- name=ir_data.Word(text="phil")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value="20"),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
- modular_value="20",
- modulus="infinity")))),
- name=ir_data.Word(text="bob"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value="10"),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
- modular_value="10",
- modulus="infinity")))),
- name=ir_data.Word(text="bob")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value="5"),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
+ def test_is_constant_integer_type(self):
+ self.assertFalse(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modulus="10",
modular_value="5",
- modulus="infinity")))),
- name=ir_data.Word(text="bob2")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value="0"),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
- modular_value="0",
- modulus="infinity")))),
- name=ir_data.Word(text="bob2"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value="30"),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
- modular_value="30",
- modulus="infinity")))),
- name=ir_data.Word(text="bob3"),
- is_default=True),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- function=ir_data.Function(
- function=ir_data.FunctionMapping.ADDITION,
- args=[
- ir_data.Expression(
- constant=ir_data.NumericConstant(value="100"),
- type=ir_data.ExpressionType(
- integer=ir_data.IntegerType(
- modular_value="100",
- modulus="infinity"))),
- ir_data.Expression(
- constant=ir_data.NumericConstant(value="100"),
- type=ir_data.ExpressionType(
- integer=ir_data.IntegerType(
- modular_value="100",
- modulus="infinity")))
- ]),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
- modular_value="200",
- modulus="infinity")))),
- name=ir_data.Word(text="bob4")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(
- expression=ir_data.Expression(
- constant=ir_data.NumericConstant(value="40"),
- type=ir_data.ExpressionType(integer=ir_data.IntegerType(
- modular_value="40",
- modulus="infinity")))),
- name=ir_data.Word()),
- ])
- self.assertEqual(10,
- ir_util.get_integer_attribute(type_def.attribute, "bob"))
- self.assertEqual(5,
- ir_util.get_integer_attribute(type_def.attribute, "bob2"))
- self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "Bob"))
- self.assertEqual(10, ir_util.get_integer_attribute(type_def.attribute,
- "Bob",
- default_value=10))
- self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "bob3"))
- self.assertEqual(200, ir_util.get_integer_attribute(type_def.attribute,
- "bob4"))
+ minimum_value="-5",
+ maximum_value="15",
+ )
+ )
+ )
+ )
+ self.assertTrue(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modulus="infinity",
+ modular_value="5",
+ minimum_value="5",
+ maximum_value="5",
+ )
+ )
+ )
+ )
- def test_get_duplicate_attribute(self):
- type_def = ir_data.TypeDefinition(attribute=[
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=ir_data.Expression()),
- name=ir_data.Word(text="phil")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("true")),
- name=ir_data.Word(text="bob")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word(text="bob")),
- ir_data.Attribute(
- value=ir_data.AttributeValue(expression=_parse_expression("false")),
- name=ir_data.Word()),
- ])
- self.assertRaises(AssertionError, ir_util.get_attribute, type_def.attribute,
- "bob")
+ def test_is_constant_boolean_type(self):
+ self.assertFalse(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(boolean=ir_data.BooleanType())
+ )
+ )
+ self.assertTrue(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(boolean=ir_data.BooleanType(value=True))
+ )
+ )
+ self.assertTrue(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(boolean=ir_data.BooleanType(value=False))
+ )
+ )
- def test_find_object(self):
- ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr,
- """{
+ def test_is_constant_enumeration_type(self):
+ self.assertFalse(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(enumeration=ir_data.EnumType())
+ )
+ )
+ self.assertTrue(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(enumeration=ir_data.EnumType(value="0"))
+ )
+ )
+
+ def test_is_constant_opaque_type(self):
+ self.assertFalse(
+ ir_util.is_constant_type(
+ ir_data.ExpressionType(opaque=ir_data.OpaqueType())
+ )
+ )
+
+ def test_constant_value_of_integer(self):
+ self.assertEqual(6, ir_util.constant_value(_parse_expression("6")))
+
+ def test_constant_value_of_none(self):
+ self.assertIsNone(ir_util.constant_value(ir_data.Expression()))
+
+ def test_constant_value_of_addition(self):
+ self.assertEqual(6, ir_util.constant_value(_parse_expression("2+4")))
+
+ def test_constant_value_of_subtraction(self):
+ self.assertEqual(-2, ir_util.constant_value(_parse_expression("2-4")))
+
+ def test_constant_value_of_multiplication(self):
+ self.assertEqual(8, ir_util.constant_value(_parse_expression("2*4")))
+
+ def test_constant_value_of_equality(self):
+ self.assertFalse(ir_util.constant_value(_parse_expression("2 == 4")))
+
+ def test_constant_value_of_inequality(self):
+ self.assertTrue(ir_util.constant_value(_parse_expression("2 != 4")))
+
+ def test_constant_value_of_less(self):
+ self.assertTrue(ir_util.constant_value(_parse_expression("2 < 4")))
+
+ def test_constant_value_of_less_or_equal(self):
+ self.assertTrue(ir_util.constant_value(_parse_expression("2 <= 4")))
+
+ def test_constant_value_of_greater(self):
+ self.assertFalse(ir_util.constant_value(_parse_expression("2 > 4")))
+
+ def test_constant_value_of_greater_or_equal(self):
+ self.assertFalse(ir_util.constant_value(_parse_expression("2 >= 4")))
+
+ def test_constant_value_of_and(self):
+ self.assertFalse(ir_util.constant_value(_parse_expression("true && false")))
+ self.assertTrue(ir_util.constant_value(_parse_expression("true && true")))
+
+ def test_constant_value_of_or(self):
+ self.assertTrue(ir_util.constant_value(_parse_expression("true || false")))
+ self.assertFalse(ir_util.constant_value(_parse_expression("false || false")))
+
+ def test_constant_value_of_choice(self):
+ self.assertEqual(
+ 10, ir_util.constant_value(_parse_expression("false ? 20 : 10"))
+ )
+ self.assertEqual(
+ 20, ir_util.constant_value(_parse_expression("true ? 20 : 10"))
+ )
+
+ def test_constant_value_of_choice_with_unknown_other_branch(self):
+ self.assertEqual(
+ 10, ir_util.constant_value(_parse_expression("false ? foo : 10"))
+ )
+ self.assertEqual(
+ 20, ir_util.constant_value(_parse_expression("true ? 20 : foo"))
+ )
+
+ def test_constant_value_of_maximum(self):
+ self.assertEqual(10, ir_util.constant_value(_parse_expression("$max(5, 10)")))
+ self.assertEqual(10, ir_util.constant_value(_parse_expression("$max(10)")))
+ self.assertEqual(
+ 10, ir_util.constant_value(_parse_expression("$max(5, 9, 7, 10, 6, -100)"))
+ )
+
+ def test_constant_value_of_boolean(self):
+ self.assertTrue(ir_util.constant_value(_parse_expression("true")))
+ self.assertFalse(ir_util.constant_value(_parse_expression("false")))
+
+ def test_constant_value_of_enum(self):
+ self.assertEqual(
+ 12,
+ ir_util.constant_value(
+ ir_data.Expression(
+ constant_reference=ir_data.Reference(),
+ type=ir_data.ExpressionType(
+ enumeration=ir_data.EnumType(value="12")
+ ),
+ )
+ ),
+ )
+
+ def test_constant_value_of_integer_reference(self):
+ self.assertEqual(
+ 12,
+ ir_util.constant_value(
+ ir_data.Expression(
+ constant_reference=ir_data.Reference(),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modulus="infinity", modular_value="12"
+ )
+ ),
+ )
+ ),
+ )
+
+ def test_constant_value_of_boolean_reference(self):
+ self.assertTrue(
+ ir_util.constant_value(
+ ir_data.Expression(
+ constant_reference=ir_data.Reference(),
+ type=ir_data.ExpressionType(
+ boolean=ir_data.BooleanType(value=True)
+ ),
+ )
+ )
+ )
+
+ def test_constant_value_of_builtin_reference(self):
+ self.assertEqual(
+ 12,
+ ir_util.constant_value(
+ ir_data.Expression(
+ builtin_reference=ir_data.Reference(
+ canonical_name=ir_data.CanonicalName(object_path=["$foo"])
+ )
+ ),
+ {"$foo": 12},
+ ),
+ )
+
+ def test_constant_value_of_field_reference(self):
+ self.assertIsNone(ir_util.constant_value(_parse_expression("foo")))
+
+ def test_constant_value_of_missing_builtin_reference(self):
+ self.assertIsNone(
+ ir_util.constant_value(
+ _parse_expression("$static_size_in_bits"), {"$bar": 12}
+ )
+ )
+
+ def test_constant_value_of_present_builtin_reference(self):
+ self.assertEqual(
+ 12,
+ ir_util.constant_value(
+ _parse_expression("$static_size_in_bits"), {"$static_size_in_bits": 12}
+ ),
+ )
+
+ def test_constant_false_value_of_operator_and_with_missing_value(self):
+ self.assertIs(
+ False,
+ ir_util.constant_value(_parse_expression("false && foo"), {"bar": 12}),
+ )
+ self.assertIs(
+ False,
+ ir_util.constant_value(_parse_expression("foo && false"), {"bar": 12}),
+ )
+
+ def test_constant_false_value_of_operator_and_known_value(self):
+ self.assertTrue(
+ ir_util.constant_value(
+ _parse_expression("true && $is_statically_sized"),
+ {"$is_statically_sized": True},
+ )
+ )
+
+ def test_constant_none_value_of_operator_and_with_missing_value(self):
+ self.assertIsNone(
+ ir_util.constant_value(_parse_expression("true && foo"), {"bar": 12})
+ )
+ self.assertIsNone(
+ ir_util.constant_value(_parse_expression("foo && true"), {"bar": 12})
+ )
+
+ def test_constant_false_value_of_operator_or_with_missing_value(self):
+ self.assertTrue(
+ ir_util.constant_value(_parse_expression("true || foo"), {"bar": 12})
+ )
+ self.assertTrue(
+ ir_util.constant_value(_parse_expression("foo || true"), {"bar": 12})
+ )
+
+ def test_constant_none_value_of_operator_or_with_missing_value(self):
+ self.assertIsNone(
+ ir_util.constant_value(_parse_expression("foo || false"), {"bar": 12})
+ )
+ self.assertIsNone(
+ ir_util.constant_value(_parse_expression("false || foo"), {"bar": 12})
+ )
+
+ def test_constant_value_of_operator_plus_with_missing_value(self):
+ self.assertIsNone(
+ ir_util.constant_value(_parse_expression("12 + foo"), {"bar": 12})
+ )
+
+ def test_is_array(self):
+ self.assertTrue(ir_util.is_array(ir_data.Type(array_type=ir_data.ArrayType())))
+ self.assertFalse(
+ ir_util.is_array(ir_data.Type(atomic_type=ir_data.AtomicType()))
+ )
+
+ def test_get_attribute(self):
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("true")),
+ name=ir_data.Word(text="bob"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob2"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("true")),
+ name=ir_data.Word(text="bob2"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob3"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(),
+ ),
+ ]
+ )
+ self.assertEqual(
+ ir_data.AttributeValue(expression=_parse_expression("true")),
+ ir_util.get_attribute(type_def.attribute, "bob"),
+ )
+ self.assertEqual(
+ ir_data.AttributeValue(expression=_parse_expression("false")),
+ ir_util.get_attribute(type_def.attribute, "bob2"),
+ )
+ self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "Bob"))
+ self.assertEqual(None, ir_util.get_attribute(type_def.attribute, "bob3"))
+
+ def test_get_boolean_attribute(self):
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("true")),
+ name=ir_data.Word(text="bob"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob2"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("true")),
+ name=ir_data.Word(text="bob2"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob3"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(),
+ ),
+ ]
+ )
+ self.assertTrue(ir_util.get_boolean_attribute(type_def.attribute, "bob"))
+ self.assertTrue(
+ ir_util.get_boolean_attribute(
+ type_def.attribute, "bob", default_value=False
+ )
+ )
+ self.assertFalse(ir_util.get_boolean_attribute(type_def.attribute, "bob2"))
+ self.assertFalse(
+ ir_util.get_boolean_attribute(
+ type_def.attribute, "bob2", default_value=True
+ )
+ )
+ self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "Bob"))
+ self.assertTrue(
+ ir_util.get_boolean_attribute(type_def.attribute, "Bob", default_value=True)
+ )
+ self.assertIsNone(ir_util.get_boolean_attribute(type_def.attribute, "bob3"))
+
+ def test_get_integer_attribute(self):
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ type=ir_data.ExpressionType(integer=ir_data.IntegerType())
+ )
+ ),
+ name=ir_data.Word(text="phil"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(value="20"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="20", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(text="bob"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(value="10"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="10", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(text="bob"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(value="5"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="5", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(text="bob2"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(value="0"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="0", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(text="bob2"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(value="30"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="30", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(text="bob3"),
+ is_default=True,
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ function=ir_data.Function(
+ function=ir_data.FunctionMapping.ADDITION,
+ args=[
+ ir_data.Expression(
+ constant=ir_data.NumericConstant(value="100"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="100", modulus="infinity"
+ )
+ ),
+ ),
+ ir_data.Expression(
+ constant=ir_data.NumericConstant(value="100"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="100", modulus="infinity"
+ )
+ ),
+ ),
+ ],
+ ),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="200", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(text="bob4"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ constant=ir_data.NumericConstant(value="40"),
+ type=ir_data.ExpressionType(
+ integer=ir_data.IntegerType(
+ modular_value="40", modulus="infinity"
+ )
+ ),
+ )
+ ),
+ name=ir_data.Word(),
+ ),
+ ]
+ )
+ self.assertEqual(10, ir_util.get_integer_attribute(type_def.attribute, "bob"))
+ self.assertEqual(5, ir_util.get_integer_attribute(type_def.attribute, "bob2"))
+ self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "Bob"))
+ self.assertEqual(
+ 10,
+ ir_util.get_integer_attribute(type_def.attribute, "Bob", default_value=10),
+ )
+ self.assertIsNone(ir_util.get_integer_attribute(type_def.attribute, "bob3"))
+ self.assertEqual(200, ir_util.get_integer_attribute(type_def.attribute, "bob4"))
+
+ def test_get_duplicate_attribute(self):
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("true")),
+ name=ir_data.Word(text="bob"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(text="bob"),
+ ),
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=_parse_expression("false")),
+ name=ir_data.Word(),
+ ),
+ ]
+ )
+ self.assertRaises(
+ AssertionError, ir_util.get_attribute, type_def.attribute, "bob"
+ )
+
+ def test_find_object(self):
+ ir = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.EmbossIr,
+ """{
"module": [
{
"type": [
@@ -490,83 +652,128 @@
"source_file_name": ""
}
]
- }""")
+ }""",
+ )
- # Test that find_object works with any of its four "name" types.
- canonical_name_of_foo = ir_data.CanonicalName(module_file="test.emb",
- object_path=["Foo"])
- self.assertEqual(ir.module[0].type[0], ir_util.find_object(
- ir_data.Reference(canonical_name=canonical_name_of_foo), ir))
- self.assertEqual(ir.module[0].type[0], ir_util.find_object(
- ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir))
- self.assertEqual(ir.module[0].type[0],
- ir_util.find_object(canonical_name_of_foo, ir))
- self.assertEqual(ir.module[0].type[0],
- ir_util.find_object(("test.emb", "Foo"), ir))
+ # Test that find_object works with any of its four "name" types.
+ canonical_name_of_foo = ir_data.CanonicalName(
+ module_file="test.emb", object_path=["Foo"]
+ )
+ self.assertEqual(
+ ir.module[0].type[0],
+ ir_util.find_object(
+ ir_data.Reference(canonical_name=canonical_name_of_foo), ir
+ ),
+ )
+ self.assertEqual(
+ ir.module[0].type[0],
+ ir_util.find_object(
+ ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir
+ ),
+ )
+ self.assertEqual(
+ ir.module[0].type[0], ir_util.find_object(canonical_name_of_foo, ir)
+ )
+ self.assertEqual(
+ ir.module[0].type[0], ir_util.find_object(("test.emb", "Foo"), ir)
+ )
- # Test that find_object works with objects other than structs.
- self.assertEqual(ir.module[0].type[1],
- ir_util.find_object(("test.emb", "Bar"), ir))
- self.assertEqual(ir.module[1].type[0],
- ir_util.find_object(("", "UInt"), ir))
- self.assertEqual(ir.module[0].type[0].structure.field[0],
- ir_util.find_object(("test.emb", "Foo", "field"), ir))
- self.assertEqual(ir.module[0].type[0].runtime_parameter[0],
- ir_util.find_object(("test.emb", "Foo", "parameter"), ir))
- self.assertEqual(ir.module[0].type[1].enumeration.value[0],
- ir_util.find_object(("test.emb", "Bar", "QUX"), ir))
- self.assertEqual(ir.module[0], ir_util.find_object(("test.emb",), ir))
- self.assertEqual(ir.module[1], ir_util.find_object(("",), ir))
+ # Test that find_object works with objects other than structs.
+ self.assertEqual(
+ ir.module[0].type[1], ir_util.find_object(("test.emb", "Bar"), ir)
+ )
+ self.assertEqual(ir.module[1].type[0], ir_util.find_object(("", "UInt"), ir))
+ self.assertEqual(
+ ir.module[0].type[0].structure.field[0],
+ ir_util.find_object(("test.emb", "Foo", "field"), ir),
+ )
+ self.assertEqual(
+ ir.module[0].type[0].runtime_parameter[0],
+ ir_util.find_object(("test.emb", "Foo", "parameter"), ir),
+ )
+ self.assertEqual(
+ ir.module[0].type[1].enumeration.value[0],
+ ir_util.find_object(("test.emb", "Bar", "QUX"), ir),
+ )
+ self.assertEqual(ir.module[0], ir_util.find_object(("test.emb",), ir))
+ self.assertEqual(ir.module[1], ir_util.find_object(("",), ir))
- # Test searching for non-present objects.
- self.assertIsNone(ir_util.find_object_or_none(("not_test.emb",), ir))
- self.assertIsNone(ir_util.find_object_or_none(("test.emb", "NotFoo"), ir))
- self.assertIsNone(
- ir_util.find_object_or_none(("test.emb", "Foo", "not_field"), ir))
- self.assertIsNone(
- ir_util.find_object_or_none(("test.emb", "Foo", "field", "no_subfield"),
- ir))
- self.assertIsNone(
- ir_util.find_object_or_none(("test.emb", "Bar", "NOT_QUX"), ir))
- self.assertIsNone(
- ir_util.find_object_or_none(("test.emb", "Bar", "QUX", "no_subenum"),
- ir))
+ # Test searching for non-present objects.
+ self.assertIsNone(ir_util.find_object_or_none(("not_test.emb",), ir))
+ self.assertIsNone(ir_util.find_object_or_none(("test.emb", "NotFoo"), ir))
+ self.assertIsNone(
+ ir_util.find_object_or_none(("test.emb", "Foo", "not_field"), ir)
+ )
+ self.assertIsNone(
+ ir_util.find_object_or_none(("test.emb", "Foo", "field", "no_subfield"), ir)
+ )
+ self.assertIsNone(
+ ir_util.find_object_or_none(("test.emb", "Bar", "NOT_QUX"), ir)
+ )
+ self.assertIsNone(
+ ir_util.find_object_or_none(("test.emb", "Bar", "QUX", "no_subenum"), ir)
+ )
- # Test that find_parent_object works with any of its four "name" types.
- self.assertEqual(ir.module[0], ir_util.find_parent_object(
- ir_data.Reference(canonical_name=canonical_name_of_foo), ir))
- self.assertEqual(ir.module[0], ir_util.find_parent_object(
- ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir))
- self.assertEqual(ir.module[0],
- ir_util.find_parent_object(canonical_name_of_foo, ir))
- self.assertEqual(ir.module[0],
- ir_util.find_parent_object(("test.emb", "Foo"), ir))
+ # Test that find_parent_object works with any of its four "name" types.
+ self.assertEqual(
+ ir.module[0],
+ ir_util.find_parent_object(
+ ir_data.Reference(canonical_name=canonical_name_of_foo), ir
+ ),
+ )
+ self.assertEqual(
+ ir.module[0],
+ ir_util.find_parent_object(
+ ir_data.NameDefinition(canonical_name=canonical_name_of_foo), ir
+ ),
+ )
+ self.assertEqual(
+ ir.module[0], ir_util.find_parent_object(canonical_name_of_foo, ir)
+ )
+ self.assertEqual(
+ ir.module[0], ir_util.find_parent_object(("test.emb", "Foo"), ir)
+ )
- # Test that find_parent_object works with objects other than structs.
- self.assertEqual(ir.module[0],
- ir_util.find_parent_object(("test.emb", "Bar"), ir))
- self.assertEqual(ir.module[1], ir_util.find_parent_object(("", "UInt"), ir))
- self.assertEqual(ir.module[0].type[0],
- ir_util.find_parent_object(("test.emb", "Foo", "field"),
- ir))
- self.assertEqual(ir.module[0].type[1],
- ir_util.find_parent_object(("test.emb", "Bar", "QUX"), ir))
+ # Test that find_parent_object works with objects other than structs.
+ self.assertEqual(
+ ir.module[0], ir_util.find_parent_object(("test.emb", "Bar"), ir)
+ )
+ self.assertEqual(ir.module[1], ir_util.find_parent_object(("", "UInt"), ir))
+ self.assertEqual(
+ ir.module[0].type[0],
+ ir_util.find_parent_object(("test.emb", "Foo", "field"), ir),
+ )
+ self.assertEqual(
+ ir.module[0].type[1],
+ ir_util.find_parent_object(("test.emb", "Bar", "QUX"), ir),
+ )
- def test_hashable_form_of_reference(self):
- self.assertEqual(
- ("t.emb", "Foo", "Bar"),
- ir_util.hashable_form_of_reference(ir_data.Reference(
- canonical_name=ir_data.CanonicalName(module_file="t.emb",
- object_path=["Foo", "Bar"]))))
- self.assertEqual(
- ("t.emb", "Foo", "Bar"),
- ir_util.hashable_form_of_reference(ir_data.NameDefinition(
- canonical_name=ir_data.CanonicalName(module_file="t.emb",
- object_path=["Foo", "Bar"]))))
+ def test_hashable_form_of_reference(self):
+ self.assertEqual(
+ ("t.emb", "Foo", "Bar"),
+ ir_util.hashable_form_of_reference(
+ ir_data.Reference(
+ canonical_name=ir_data.CanonicalName(
+ module_file="t.emb", object_path=["Foo", "Bar"]
+ )
+ )
+ ),
+ )
+ self.assertEqual(
+ ("t.emb", "Foo", "Bar"),
+ ir_util.hashable_form_of_reference(
+ ir_data.NameDefinition(
+ canonical_name=ir_data.CanonicalName(
+ module_file="t.emb", object_path=["Foo", "Bar"]
+ )
+ )
+ ),
+ )
- def test_get_base_type(self):
- array_type_ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ def test_get_base_type(self):
+ array_type_ir = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"array_type": {
"element_count": { "constant": { "value": "20" } },
"base_type": {
@@ -583,16 +790,19 @@
},
"source_location": { "start": { "line": 3 } }
}
- }""")
- base_type_ir = array_type_ir.array_type.base_type.array_type.base_type
- self.assertEqual(base_type_ir, ir_util.get_base_type(array_type_ir))
- self.assertEqual(base_type_ir, ir_util.get_base_type(
- array_type_ir.array_type.base_type))
- self.assertEqual(base_type_ir, ir_util.get_base_type(base_type_ir))
+ }""",
+ )
+ base_type_ir = array_type_ir.array_type.base_type.array_type.base_type
+ self.assertEqual(base_type_ir, ir_util.get_base_type(array_type_ir))
+ self.assertEqual(
+ base_type_ir, ir_util.get_base_type(array_type_ir.array_type.base_type)
+ )
+ self.assertEqual(base_type_ir, ir_util.get_base_type(base_type_ir))
- def test_size_of_type_in_bits(self):
- ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr,
- """{
+ def test_size_of_type_in_bits(self):
+ ir = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.EmbossIr,
+ """{
"module": [{
"type": [{
"name": {
@@ -637,20 +847,24 @@
}],
"source_file_name": ""
}]
- }""")
+ }""",
+ )
- fixed_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ fixed_size_type = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"atomic_type": {
"reference": {
"canonical_name": { "module_file": "", "object_path": ["Byte"] }
}
}
- }""")
- self.assertEqual(8, ir_util.fixed_size_of_type_in_bits(fixed_size_type, ir))
+ }""",
+ )
+ self.assertEqual(8, ir_util.fixed_size_of_type_in_bits(fixed_size_type, ir))
- explicit_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ explicit_size_type = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"atomic_type": {
"reference": {
"canonical_name": { "module_file": "", "object_path": ["UInt"] }
@@ -662,12 +876,13 @@
"integer": { "modular_value": "32", "modulus": "infinity" }
}
}
- }""")
- self.assertEqual(32,
- ir_util.fixed_size_of_type_in_bits(explicit_size_type, ir))
+ }""",
+ )
+ self.assertEqual(32, ir_util.fixed_size_of_type_in_bits(explicit_size_type, ir))
- fixed_size_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ fixed_size_array = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"array_type": {
"base_type": {
"atomic_type": {
@@ -683,12 +898,13 @@
}
}
}
- }""")
- self.assertEqual(40,
- ir_util.fixed_size_of_type_in_bits(fixed_size_array, ir))
+ }""",
+ )
+ self.assertEqual(40, ir_util.fixed_size_of_type_in_bits(fixed_size_array, ir))
- fixed_size_2d_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ fixed_size_2d_array = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"array_type": {
"base_type": {
"array_type": {
@@ -717,12 +933,15 @@
}
}
}
- }""")
- self.assertEqual(
- 80, ir_util.fixed_size_of_type_in_bits(fixed_size_2d_array, ir))
+ }""",
+ )
+ self.assertEqual(
+ 80, ir_util.fixed_size_of_type_in_bits(fixed_size_2d_array, ir)
+ )
- automatic_size_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ automatic_size_array = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"array_type": {
"base_type": {
"array_type": {
@@ -746,23 +965,25 @@
},
"automatic": { }
}
- }""")
- self.assertIsNone(
- ir_util.fixed_size_of_type_in_bits(automatic_size_array, ir))
+ }""",
+ )
+ self.assertIsNone(ir_util.fixed_size_of_type_in_bits(automatic_size_array, ir))
- variable_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ variable_size_type = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"atomic_type": {
"reference": {
"canonical_name": { "module_file": "", "object_path": ["UInt"] }
}
}
- }""")
- self.assertIsNone(
- ir_util.fixed_size_of_type_in_bits(variable_size_type, ir))
+ }""",
+ )
+ self.assertIsNone(ir_util.fixed_size_of_type_in_bits(variable_size_type, ir))
- no_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
- """{
+ no_size_type = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.Type,
+ """{
"atomic_type": {
"reference": {
"canonical_name": {
@@ -771,26 +992,35 @@
}
}
}
- }""")
- self.assertIsNone(ir_util.fixed_size_of_type_in_bits(no_size_type, ir))
+ }""",
+ )
+ self.assertIsNone(ir_util.fixed_size_of_type_in_bits(no_size_type, ir))
- def test_field_is_virtual(self):
- self.assertTrue(ir_util.field_is_virtual(ir_data.Field()))
+ def test_field_is_virtual(self):
+ self.assertTrue(ir_util.field_is_virtual(ir_data.Field()))
- def test_field_is_not_virtual(self):
- self.assertFalse(ir_util.field_is_virtual(
- ir_data.Field(location=ir_data.FieldLocation())))
+ def test_field_is_not_virtual(self):
+ self.assertFalse(
+ ir_util.field_is_virtual(ir_data.Field(location=ir_data.FieldLocation()))
+ )
- def test_field_is_read_only(self):
- self.assertTrue(ir_util.field_is_read_only(ir_data.Field(
- write_method=ir_data.WriteMethod(read_only=True))))
+ def test_field_is_read_only(self):
+ self.assertTrue(
+ ir_util.field_is_read_only(
+ ir_data.Field(write_method=ir_data.WriteMethod(read_only=True))
+ )
+ )
- def test_field_is_not_read_only(self):
- self.assertFalse(ir_util.field_is_read_only(
- ir_data.Field(location=ir_data.FieldLocation())))
- self.assertFalse(ir_util.field_is_read_only(ir_data.Field(
- write_method=ir_data.WriteMethod())))
+ def test_field_is_not_read_only(self):
+ self.assertFalse(
+ ir_util.field_is_read_only(ir_data.Field(location=ir_data.FieldLocation()))
+ )
+ self.assertFalse(
+ ir_util.field_is_read_only(
+ ir_data.Field(write_method=ir_data.WriteMethod())
+ )
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/name_conversion.py b/compiler/util/name_conversion.py
index 3f32809..b13d667 100644
--- a/compiler/util/name_conversion.py
+++ b/compiler/util/name_conversion.py
@@ -18,10 +18,10 @@
class Case(str, Enum):
- SNAKE = "snake_case"
- SHOUTY = "SHOUTY_CASE"
- CAMEL = "CamelCase"
- K_CAMEL = "kCamelCase"
+ SNAKE = "snake_case"
+ SHOUTY = "SHOUTY_CASE"
+ CAMEL = "CamelCase"
+ K_CAMEL = "kCamelCase"
# Map of (from, to) cases to their conversion function. Initially only contains
@@ -31,41 +31,42 @@
def _case_conversion(case_from, case_to):
- """Decorator to dynamically dispatch case conversions at runtime."""
- def _func(f):
- _case_conversions[case_from, case_to] = f
- return f
+ """Decorator to dynamically dispatch case conversions at runtime."""
- return _func
+ def _func(f):
+ _case_conversions[case_from, case_to] = f
+ return f
+
+ return _func
@_case_conversion(Case.SNAKE, Case.CAMEL)
@_case_conversion(Case.SHOUTY, Case.CAMEL)
def snake_to_camel(name):
- """Convert from snake_case to CamelCase. Also works from SHOUTY_CASE."""
- return "".join(word.capitalize() for word in name.split("_"))
+ """Convert from snake_case to CamelCase. Also works from SHOUTY_CASE."""
+ return "".join(word.capitalize() for word in name.split("_"))
@_case_conversion(Case.CAMEL, Case.K_CAMEL)
def camel_to_k_camel(name):
- """Convert from CamelCase to kCamelCase."""
- return "k" + name
+ """Convert from CamelCase to kCamelCase."""
+ return "k" + name
@_case_conversion(Case.SNAKE, Case.K_CAMEL)
@_case_conversion(Case.SHOUTY, Case.K_CAMEL)
def snake_to_k_camel(name):
- """Converts from snake_case to kCamelCase. Also works from SHOUTY_CASE."""
- return camel_to_k_camel(snake_to_camel(name))
+ """Converts from snake_case to kCamelCase. Also works from SHOUTY_CASE."""
+ return camel_to_k_camel(snake_to_camel(name))
def convert_case(case_from, case_to, value):
- """Converts cases based on runtime case values.
+ """Converts cases based on runtime case values.
- Note: Cases can be strings or enum values."""
- return _case_conversions[case_from, case_to](value)
+ Note: Cases can be strings or enum values."""
+ return _case_conversions[case_from, case_to](value)
def is_case_conversion_supported(case_from, case_to):
- """Determines if a case conversion would be supported"""
- return (case_from, case_to) in _case_conversions
+ """Determines if a case conversion would be supported"""
+ return (case_from, case_to) in _case_conversions
diff --git a/compiler/util/name_conversion_test.py b/compiler/util/name_conversion_test.py
index 9e41ecc..91309ea 100644
--- a/compiler/util/name_conversion_test.py
+++ b/compiler/util/name_conversion_test.py
@@ -20,70 +20,77 @@
class NameConversionTest(unittest.TestCase):
- def test_snake_to_camel(self):
- self.assertEqual("", name_conversion.snake_to_camel(""))
- self.assertEqual("Abc", name_conversion.snake_to_camel("abc"))
- self.assertEqual("AbcDef", name_conversion.snake_to_camel("abc_def"))
- self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def89"))
- self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def_89"))
- self.assertEqual("Abc89Def", name_conversion.snake_to_camel("abc_89_def"))
- self.assertEqual("Abc89def", name_conversion.snake_to_camel("abc_89def"))
+ def test_snake_to_camel(self):
+ self.assertEqual("", name_conversion.snake_to_camel(""))
+ self.assertEqual("Abc", name_conversion.snake_to_camel("abc"))
+ self.assertEqual("AbcDef", name_conversion.snake_to_camel("abc_def"))
+ self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def89"))
+ self.assertEqual("AbcDef89", name_conversion.snake_to_camel("abc_def_89"))
+ self.assertEqual("Abc89Def", name_conversion.snake_to_camel("abc_89_def"))
+ self.assertEqual("Abc89def", name_conversion.snake_to_camel("abc_89def"))
- def test_shouty_to_camel(self):
- self.assertEqual("Abc", name_conversion.snake_to_camel("ABC"))
- self.assertEqual("AbcDef", name_conversion.snake_to_camel("ABC_DEF"))
- self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF89"))
- self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF_89"))
- self.assertEqual("Abc89Def", name_conversion.snake_to_camel("ABC_89_DEF"))
- self.assertEqual("Abc89def", name_conversion.snake_to_camel("ABC_89DEF"))
+ def test_shouty_to_camel(self):
+ self.assertEqual("Abc", name_conversion.snake_to_camel("ABC"))
+ self.assertEqual("AbcDef", name_conversion.snake_to_camel("ABC_DEF"))
+ self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF89"))
+ self.assertEqual("AbcDef89", name_conversion.snake_to_camel("ABC_DEF_89"))
+ self.assertEqual("Abc89Def", name_conversion.snake_to_camel("ABC_89_DEF"))
+ self.assertEqual("Abc89def", name_conversion.snake_to_camel("ABC_89DEF"))
- def test_camel_to_k_camel(self):
- self.assertEqual("kFoo", name_conversion.camel_to_k_camel("Foo"))
- self.assertEqual("kFooBar", name_conversion.camel_to_k_camel("FooBar"))
- self.assertEqual("kAbc123", name_conversion.camel_to_k_camel("Abc123"))
+ def test_camel_to_k_camel(self):
+ self.assertEqual("kFoo", name_conversion.camel_to_k_camel("Foo"))
+ self.assertEqual("kFooBar", name_conversion.camel_to_k_camel("FooBar"))
+ self.assertEqual("kAbc123", name_conversion.camel_to_k_camel("Abc123"))
- def test_snake_to_k_camel(self):
- self.assertEqual("kAbc", name_conversion.snake_to_k_camel("abc"))
- self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("abc_def"))
- self.assertEqual("kAbcDef89",
- name_conversion.snake_to_k_camel("abc_def89"))
- self.assertEqual("kAbcDef89",
- name_conversion.snake_to_k_camel("abc_def_89"))
- self.assertEqual("kAbc89Def",
- name_conversion.snake_to_k_camel("abc_89_def"))
- self.assertEqual("kAbc89def",
- name_conversion.snake_to_k_camel("abc_89def"))
+ def test_snake_to_k_camel(self):
+ self.assertEqual("kAbc", name_conversion.snake_to_k_camel("abc"))
+ self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("abc_def"))
+ self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("abc_def89"))
+ self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("abc_def_89"))
+ self.assertEqual("kAbc89Def", name_conversion.snake_to_k_camel("abc_89_def"))
+ self.assertEqual("kAbc89def", name_conversion.snake_to_k_camel("abc_89def"))
- def test_shouty_to_k_camel(self):
- self.assertEqual("kAbc", name_conversion.snake_to_k_camel("ABC"))
- self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("ABC_DEF"))
- self.assertEqual("kAbcDef89",
- name_conversion.snake_to_k_camel("ABC_DEF89"))
- self.assertEqual("kAbcDef89",
- name_conversion.snake_to_k_camel("ABC_DEF_89"))
- self.assertEqual("kAbc89Def",
- name_conversion.snake_to_k_camel("ABC_89_DEF"))
- self.assertEqual("kAbc89def",
- name_conversion.snake_to_k_camel("ABC_89DEF"))
+ def test_shouty_to_k_camel(self):
+ self.assertEqual("kAbc", name_conversion.snake_to_k_camel("ABC"))
+ self.assertEqual("kAbcDef", name_conversion.snake_to_k_camel("ABC_DEF"))
+ self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("ABC_DEF89"))
+ self.assertEqual("kAbcDef89", name_conversion.snake_to_k_camel("ABC_DEF_89"))
+ self.assertEqual("kAbc89Def", name_conversion.snake_to_k_camel("ABC_89_DEF"))
+ self.assertEqual("kAbc89def", name_conversion.snake_to_k_camel("ABC_89DEF"))
- def test_convert_case(self):
- self.assertEqual("foo_bar_123", name_conversion.convert_case(
- "snake_case", "snake_case", "foo_bar_123"))
- self.assertEqual("FOO_BAR_123", name_conversion.convert_case(
- "SHOUTY_CASE", "SHOUTY_CASE", "FOO_BAR_123"))
- self.assertEqual("kFooBar123", name_conversion.convert_case(
- "kCamelCase", "kCamelCase", "kFooBar123"))
- self.assertEqual("FooBar123", name_conversion.convert_case(
- "CamelCase", "CamelCase", "FooBar123"))
- self.assertEqual("kAbcDef", name_conversion.convert_case(
- "snake_case", "kCamelCase", "abc_def"))
- self.assertEqual("AbcDef", name_conversion.convert_case(
- "snake_case", "CamelCase", "abc_def"))
- self.assertEqual("kAbcDef", name_conversion.convert_case(
- "SHOUTY_CASE", "kCamelCase", "ABC_DEF"))
- self.assertEqual("AbcDef", name_conversion.convert_case(
- "SHOUTY_CASE", "CamelCase", "ABC_DEF"))
+ def test_convert_case(self):
+ self.assertEqual(
+ "foo_bar_123",
+ name_conversion.convert_case("snake_case", "snake_case", "foo_bar_123"),
+ )
+ self.assertEqual(
+ "FOO_BAR_123",
+ name_conversion.convert_case("SHOUTY_CASE", "SHOUTY_CASE", "FOO_BAR_123"),
+ )
+ self.assertEqual(
+ "kFooBar123",
+ name_conversion.convert_case("kCamelCase", "kCamelCase", "kFooBar123"),
+ )
+ self.assertEqual(
+ "FooBar123",
+ name_conversion.convert_case("CamelCase", "CamelCase", "FooBar123"),
+ )
+ self.assertEqual(
+ "kAbcDef",
+ name_conversion.convert_case("snake_case", "kCamelCase", "abc_def"),
+ )
+ self.assertEqual(
+ "AbcDef", name_conversion.convert_case("snake_case", "CamelCase", "abc_def")
+ )
+ self.assertEqual(
+ "kAbcDef",
+ name_conversion.convert_case("SHOUTY_CASE", "kCamelCase", "ABC_DEF"),
+ )
+ self.assertEqual(
+ "AbcDef",
+ name_conversion.convert_case("SHOUTY_CASE", "CamelCase", "ABC_DEF"),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/parser_types.py b/compiler/util/parser_types.py
index 5b63ffa..fffe086 100644
--- a/compiler/util/parser_types.py
+++ b/compiler/util/parser_types.py
@@ -25,92 +25,97 @@
def _make_position(line, column):
- """Makes an ir_data.Position from line, column ints."""
- if not isinstance(line, int):
- raise ValueError("Bad line {!r}".format(line))
- elif not isinstance(column, int):
- raise ValueError("Bad column {!r}".format(column))
- return ir_data.Position(line=line, column=column)
+ """Makes an ir_data.Position from line, column ints."""
+ if not isinstance(line, int):
+ raise ValueError("Bad line {!r}".format(line))
+ elif not isinstance(column, int):
+ raise ValueError("Bad column {!r}".format(column))
+ return ir_data.Position(line=line, column=column)
def _parse_position(text):
- """Parses an ir_data.Position from "line:column" (e.g., "1:2")."""
- line, column = text.split(":")
- return _make_position(int(line), int(column))
+ """Parses an ir_data.Position from "line:column" (e.g., "1:2")."""
+ line, column = text.split(":")
+ return _make_position(int(line), int(column))
def format_position(position):
- """formats an ir_data.Position to "line:column" form."""
- return "{}:{}".format(position.line, position.column)
+ """formats an ir_data.Position to "line:column" form."""
+ return "{}:{}".format(position.line, position.column)
def make_location(start, end, is_synthetic=False):
- """Makes an ir_data.Location from (line, column) tuples or ir_data.Positions."""
- if isinstance(start, tuple):
- start = _make_position(*start)
- if isinstance(end, tuple):
- end = _make_position(*end)
- if not isinstance(start, ir_data.Position):
- raise ValueError("Bad start {!r}".format(start))
- elif not isinstance(end, ir_data.Position):
- raise ValueError("Bad end {!r}".format(end))
- elif start.line > end.line or (
- start.line == end.line and start.column > end.column):
- raise ValueError("Start {} is after end {}".format(format_position(start),
- format_position(end)))
- return ir_data.Location(start=start, end=end, is_synthetic=is_synthetic)
+ """Makes an ir_data.Location from (line, column) tuples or ir_data.Positions."""
+ if isinstance(start, tuple):
+ start = _make_position(*start)
+ if isinstance(end, tuple):
+ end = _make_position(*end)
+ if not isinstance(start, ir_data.Position):
+ raise ValueError("Bad start {!r}".format(start))
+ elif not isinstance(end, ir_data.Position):
+ raise ValueError("Bad end {!r}".format(end))
+ elif start.line > end.line or (
+ start.line == end.line and start.column > end.column
+ ):
+ raise ValueError(
+ "Start {} is after end {}".format(
+ format_position(start), format_position(end)
+ )
+ )
+ return ir_data.Location(start=start, end=end, is_synthetic=is_synthetic)
def format_location(location):
- """Formats an ir_data.Location in format "1:2-3:4" ("start-end")."""
- return "{}-{}".format(format_position(location.start),
- format_position(location.end))
+ """Formats an ir_data.Location in format "1:2-3:4" ("start-end")."""
+ return "{}-{}".format(
+ format_position(location.start), format_position(location.end)
+ )
def parse_location(text):
- """Parses an ir_data.Location from format "1:2-3:4" ("start-end")."""
- start, end = text.split("-")
- return make_location(_parse_position(start), _parse_position(end))
+ """Parses an ir_data.Location from format "1:2-3:4" ("start-end")."""
+ start, end = text.split("-")
+ return make_location(_parse_position(start), _parse_position(end))
-class Token(
- collections.namedtuple("Token", ["symbol", "text", "source_location"])):
- """A Token is a chunk of text from a source file, and a classification.
+class Token(collections.namedtuple("Token", ["symbol", "text", "source_location"])):
+ """A Token is a chunk of text from a source file, and a classification.
- Attributes:
- symbol: The name of this token ("Indent", "SnakeWord", etc.)
- text: The original text ("1234", "some_name", etc.)
- source_location: Where this token came from in the original source file.
- """
+ Attributes:
+ symbol: The name of this token ("Indent", "SnakeWord", etc.)
+ text: The original text ("1234", "some_name", etc.)
+ source_location: Where this token came from in the original source file.
+ """
- def __str__(self):
- return "{} {} {}".format(self.symbol, repr(self.text),
- format_location(self.source_location))
+ def __str__(self):
+ return "{} {} {}".format(
+ self.symbol, repr(self.text), format_location(self.source_location)
+ )
class Production(collections.namedtuple("Production", ["lhs", "rhs"])):
- """A Production is a simple production from a context-free grammar.
+ """A Production is a simple production from a context-free grammar.
- A Production takes the form:
+ A Production takes the form:
- nonterminal -> symbol*
+ nonterminal -> symbol*
- where "nonterminal" is an implicitly non-terminal symbol in the language,
- and "symbol*" is zero or more terminal or non-terminal symbols which form the
- non-terminal on the left.
+ where "nonterminal" is an implicitly non-terminal symbol in the language,
+ and "symbol*" is zero or more terminal or non-terminal symbols which form the
+ non-terminal on the left.
- Attributes:
- lhs: The non-terminal symbol on the left-hand-side of the production.
- rhs: The sequence of symbols on the right-hand-side of the production.
- """
+ Attributes:
+ lhs: The non-terminal symbol on the left-hand-side of the production.
+ rhs: The sequence of symbols on the right-hand-side of the production.
+ """
- def __str__(self):
- return str(self.lhs) + " -> " + " ".join([str(r) for r in self.rhs])
+ def __str__(self):
+ return str(self.lhs) + " -> " + " ".join([str(r) for r in self.rhs])
- @staticmethod
- def parse(production_text):
- """Parses a Production from a "symbol -> symbol symbol symbol" string."""
- words = production_text.split()
- if words[1] != "->":
- raise SyntaxError
- return Production(words[0], tuple(words[2:]))
+ @staticmethod
+ def parse(production_text):
+ """Parses a Production from a "symbol -> symbol symbol symbol" string."""
+ words = production_text.split()
+ if words[1] != "->":
+ raise SyntaxError
+ return Production(words[0], tuple(words[2:]))
diff --git a/compiler/util/parser_types_test.py b/compiler/util/parser_types_test.py
index 5e6fddf..097ece4 100644
--- a/compiler/util/parser_types_test.py
+++ b/compiler/util/parser_types_test.py
@@ -20,114 +20,141 @@
class PositionTest(unittest.TestCase):
- """Tests for Position-related functions in parser_types."""
+ """Tests for Position-related functions in parser_types."""
- def test_format_position(self):
- self.assertEqual(
- "1:2", parser_types.format_position(ir_data.Position(line=1, column=2)))
+ def test_format_position(self):
+ self.assertEqual(
+ "1:2", parser_types.format_position(ir_data.Position(line=1, column=2))
+ )
class LocationTest(unittest.TestCase):
- """Tests for Location-related functions in parser_types."""
+ """Tests for Location-related functions in parser_types."""
- def test_make_location(self):
- self.assertEqual(ir_data.Location(start=ir_data.Position(line=1,
- column=2),
- end=ir_data.Position(line=3,
- column=4),
- is_synthetic=False),
- parser_types.make_location((1, 2), (3, 4)))
- self.assertEqual(
- ir_data.Location(start=ir_data.Position(line=1,
- column=2),
- end=ir_data.Position(line=3,
- column=4),
- is_synthetic=False),
- parser_types.make_location(ir_data.Position(line=1,
- column=2),
- ir_data.Position(line=3,
- column=4)))
+ def test_make_location(self):
+ self.assertEqual(
+ ir_data.Location(
+ start=ir_data.Position(line=1, column=2),
+ end=ir_data.Position(line=3, column=4),
+ is_synthetic=False,
+ ),
+ parser_types.make_location((1, 2), (3, 4)),
+ )
+ self.assertEqual(
+ ir_data.Location(
+ start=ir_data.Position(line=1, column=2),
+ end=ir_data.Position(line=3, column=4),
+ is_synthetic=False,
+ ),
+ parser_types.make_location(
+ ir_data.Position(line=1, column=2), ir_data.Position(line=3, column=4)
+ ),
+ )
- def test_make_synthetic_location(self):
- self.assertEqual(
- ir_data.Location(start=ir_data.Position(line=1, column=2),
- end=ir_data.Position(line=3, column=4),
- is_synthetic=True),
- parser_types.make_location((1, 2), (3, 4), True))
- self.assertEqual(
- ir_data.Location(start=ir_data.Position(line=1, column=2),
- end=ir_data.Position(line=3, column=4),
- is_synthetic=True),
- parser_types.make_location(ir_data.Position(line=1, column=2),
- ir_data.Position(line=3, column=4),
- True))
+ def test_make_synthetic_location(self):
+ self.assertEqual(
+ ir_data.Location(
+ start=ir_data.Position(line=1, column=2),
+ end=ir_data.Position(line=3, column=4),
+ is_synthetic=True,
+ ),
+ parser_types.make_location((1, 2), (3, 4), True),
+ )
+ self.assertEqual(
+ ir_data.Location(
+ start=ir_data.Position(line=1, column=2),
+ end=ir_data.Position(line=3, column=4),
+ is_synthetic=True,
+ ),
+ parser_types.make_location(
+ ir_data.Position(line=1, column=2),
+ ir_data.Position(line=3, column=4),
+ True,
+ ),
+ )
- def test_make_location_type_checks(self):
- self.assertRaises(ValueError, parser_types.make_location, [1, 2], (1, 2))
- self.assertRaises(ValueError, parser_types.make_location, (1, 2), [1, 2])
+ def test_make_location_type_checks(self):
+ self.assertRaises(ValueError, parser_types.make_location, [1, 2], (1, 2))
+ self.assertRaises(ValueError, parser_types.make_location, (1, 2), [1, 2])
- def test_make_location_logic_checks(self):
- self.assertRaises(ValueError, parser_types.make_location, (3, 4), (1, 2))
- self.assertRaises(ValueError, parser_types.make_location, (1, 3), (1, 2))
- self.assertTrue(parser_types.make_location((1, 2), (1, 2)))
+ def test_make_location_logic_checks(self):
+ self.assertRaises(ValueError, parser_types.make_location, (3, 4), (1, 2))
+ self.assertRaises(ValueError, parser_types.make_location, (1, 3), (1, 2))
+ self.assertTrue(parser_types.make_location((1, 2), (1, 2)))
- def test_format_location(self):
- self.assertEqual("1:2-3:4",
- parser_types.format_location(parser_types.make_location(
- (1, 2), (3, 4))))
+ def test_format_location(self):
+ self.assertEqual(
+ "1:2-3:4",
+ parser_types.format_location(parser_types.make_location((1, 2), (3, 4))),
+ )
- def test_parse_location(self):
- self.assertEqual(parser_types.make_location((1, 2), (3, 4)),
- parser_types.parse_location("1:2-3:4"))
- self.assertEqual(parser_types.make_location((1, 2), (3, 4)),
- parser_types.parse_location(" 1 : 2 - 3 : 4 "))
+ def test_parse_location(self):
+ self.assertEqual(
+ parser_types.make_location((1, 2), (3, 4)),
+ parser_types.parse_location("1:2-3:4"),
+ )
+ self.assertEqual(
+ parser_types.make_location((1, 2), (3, 4)),
+ parser_types.parse_location(" 1 : 2 - 3 : 4 "),
+ )
class TokenTest(unittest.TestCase):
- """Tests for parser_types.Token."""
+ """Tests for parser_types.Token."""
- def test_str(self):
- self.assertEqual("FOO 'bar' 1:2-3:4", str(parser_types.Token(
- "FOO", "bar", parser_types.make_location((1, 2), (3, 4)))))
+ def test_str(self):
+ self.assertEqual(
+ "FOO 'bar' 1:2-3:4",
+ str(
+ parser_types.Token(
+ "FOO", "bar", parser_types.make_location((1, 2), (3, 4))
+ )
+ ),
+ )
class ProductionTest(unittest.TestCase):
- """Tests for parser_types.Production."""
+ """Tests for parser_types.Production."""
- def test_parse(self):
- self.assertEqual(parser_types.Production(lhs="A",
- rhs=("B", "C")),
- parser_types.Production.parse("A -> B C"))
- self.assertEqual(parser_types.Production(lhs="A",
- rhs=("B",)),
- parser_types.Production.parse("A -> B"))
- self.assertEqual(parser_types.Production(lhs="A",
- rhs=("B", "C")),
- parser_types.Production.parse(" A -> B C "))
- self.assertEqual(parser_types.Production(lhs="A",
- rhs=tuple()),
- parser_types.Production.parse("A ->"))
- self.assertEqual(parser_types.Production(lhs="A",
- rhs=tuple()),
- parser_types.Production.parse("A -> "))
- self.assertEqual(parser_types.Production(lhs="FOO",
- rhs=('"B"', "x*")),
- parser_types.Production.parse('FOO -> "B" x*'))
- self.assertRaises(SyntaxError, parser_types.Production.parse, "F-> A B")
- self.assertRaises(SyntaxError, parser_types.Production.parse, "F B -> A B")
- self.assertRaises(SyntaxError, parser_types.Production.parse, "-> A B")
+ def test_parse(self):
+ self.assertEqual(
+ parser_types.Production(lhs="A", rhs=("B", "C")),
+ parser_types.Production.parse("A -> B C"),
+ )
+ self.assertEqual(
+ parser_types.Production(lhs="A", rhs=("B",)),
+ parser_types.Production.parse("A -> B"),
+ )
+ self.assertEqual(
+ parser_types.Production(lhs="A", rhs=("B", "C")),
+ parser_types.Production.parse(" A -> B C "),
+ )
+ self.assertEqual(
+ parser_types.Production(lhs="A", rhs=tuple()),
+ parser_types.Production.parse("A ->"),
+ )
+ self.assertEqual(
+ parser_types.Production(lhs="A", rhs=tuple()),
+ parser_types.Production.parse("A -> "),
+ )
+ self.assertEqual(
+ parser_types.Production(lhs="FOO", rhs=('"B"', "x*")),
+ parser_types.Production.parse('FOO -> "B" x*'),
+ )
+ self.assertRaises(SyntaxError, parser_types.Production.parse, "F-> A B")
+ self.assertRaises(SyntaxError, parser_types.Production.parse, "F B -> A B")
+ self.assertRaises(SyntaxError, parser_types.Production.parse, "-> A B")
- def test_str(self):
- self.assertEqual(str(parser_types.Production(lhs="A",
- rhs=("B", "C"))), "A -> B C")
- self.assertEqual(str(parser_types.Production(lhs="A",
- rhs=("B",))), "A -> B")
- self.assertEqual(str(parser_types.Production(lhs="A",
- rhs=tuple())), "A -> ")
- self.assertEqual(str(parser_types.Production(lhs="FOO",
- rhs=('"B"', "x*"))),
- 'FOO -> "B" x*')
+ def test_str(self):
+ self.assertEqual(
+ str(parser_types.Production(lhs="A", rhs=("B", "C"))), "A -> B C"
+ )
+ self.assertEqual(str(parser_types.Production(lhs="A", rhs=("B",))), "A -> B")
+ self.assertEqual(str(parser_types.Production(lhs="A", rhs=tuple())), "A -> ")
+ self.assertEqual(
+ str(parser_types.Production(lhs="FOO", rhs=('"B"', "x*"))), 'FOO -> "B" x*'
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/resources.py b/compiler/util/resources.py
index 606b1d9..5be66af 100644
--- a/compiler/util/resources.py
+++ b/compiler/util/resources.py
@@ -16,8 +16,10 @@
import importlib.resources
+
def load(package, file, encoding="utf-8"):
- """Returns the contents of `file` from the Python package loader."""
- with importlib.resources.files(
- package).joinpath(file).open("r", encoding=encoding) as f:
- return f.read()
+ """Returns the contents of `file` from the Python package loader."""
+ with importlib.resources.files(package).joinpath(file).open(
+ "r", encoding=encoding
+ ) as f:
+ return f.read()
diff --git a/compiler/util/simple_memoizer.py b/compiler/util/simple_memoizer.py
index 15a1719..fbd46c4 100644
--- a/compiler/util/simple_memoizer.py
+++ b/compiler/util/simple_memoizer.py
@@ -16,54 +16,54 @@
def memoize(f):
- """Memoizes f.
+ """Memoizes f.
- The @memoize decorator returns a function which caches the results of f, and
- returns directly from the cache instead of calling f when it is called again
- with the same arguments.
+ The @memoize decorator returns a function which caches the results of f, and
+ returns directly from the cache instead of calling f when it is called again
+ with the same arguments.
- Memoization has some caveats:
+ Memoization has some caveats:
- Most importantly, the decorated function will not be called every time the
- function is called. If the memoized function `f` performs I/O or relies on
- or changes global state, it may not work correctly when memoized.
+ Most importantly, the decorated function will not be called every time the
+ function is called. If the memoized function `f` performs I/O or relies on
+ or changes global state, it may not work correctly when memoized.
- This memoizer only works for functions taking positional arguments. It does
- not handle keywork arguments.
+ This memoizer only works for functions taking positional arguments. It does
+ not handle keywork arguments.
- This memoizer only works for hashable arguments -- tuples, ints, etc. It does
- not work on most iterables.
+ This memoizer only works for hashable arguments -- tuples, ints, etc. It does
+ not work on most iterables.
- This memoizer returns a function whose __name__ and argument list may differ
- from the memoized function under reflection.
+ This memoizer returns a function whose __name__ and argument list may differ
+ from the memoized function under reflection.
- This memoizer never evicts anything from its cache, so its memory usage can
- grow indefinitely.
+ This memoizer never evicts anything from its cache, so its memory usage can
+ grow indefinitely.
- Depending on the workload and speed of `f`, the memoized `f` can be slower
- than unadorned `f`; it is important to use profiling before and after
- memoization.
+ Depending on the workload and speed of `f`, the memoized `f` can be slower
+ than unadorned `f`; it is important to use profiling before and after
+ memoization.
- Usage:
- @memoize
- def function(arg, arg2, arg3):
- ...
+ Usage:
+ @memoize
+ def function(arg, arg2, arg3):
+ ...
- Arguments:
- f: The function to memoize.
+ Arguments:
+ f: The function to memoize.
- Returns:
- A function which acts like f, but faster when called repeatedly with the
- same arguments.
- """
- cache = {}
+ Returns:
+ A function which acts like f, but faster when called repeatedly with the
+ same arguments.
+ """
+ cache = {}
- def _memoized(*args):
- assert all(arg.__hash__ for arg in args), (
- "Arguments to memoized function {} must be hashable.".format(
- f.__name__))
- if args not in cache:
- cache[args] = f(*args)
- return cache[args]
+ def _memoized(*args):
+ assert all(
+ arg.__hash__ for arg in args
+ ), "Arguments to memoized function {} must be hashable.".format(f.__name__)
+ if args not in cache:
+ cache[args] = f(*args)
+ return cache[args]
- return _memoized
+ return _memoized
diff --git a/compiler/util/simple_memoizer_test.py b/compiler/util/simple_memoizer_test.py
index cfc35c7..855b3f4 100644
--- a/compiler/util/simple_memoizer_test.py
+++ b/compiler/util/simple_memoizer_test.py
@@ -20,54 +20,55 @@
class SimpleMemoizerTest(unittest.TestCase):
- def test_memoized_function_returns_same_values(self):
- @simple_memoizer.memoize
- def add_one(n):
- return n + 1
+ def test_memoized_function_returns_same_values(self):
+ @simple_memoizer.memoize
+ def add_one(n):
+ return n + 1
- for i in range(100):
- self.assertEqual(i + 1, add_one(i))
+ for i in range(100):
+ self.assertEqual(i + 1, add_one(i))
- def test_memoized_function_is_only_called_once(self):
- arguments = []
+ def test_memoized_function_is_only_called_once(self):
+ arguments = []
- @simple_memoizer.memoize
- def add_one_and_add_argument_to_list(n):
- arguments.append(n)
- return n + 1
+ @simple_memoizer.memoize
+ def add_one_and_add_argument_to_list(n):
+ arguments.append(n)
+ return n + 1
- self.assertEqual(1, add_one_and_add_argument_to_list(0))
- self.assertEqual([0], arguments)
- self.assertEqual(1, add_one_and_add_argument_to_list(0))
- self.assertEqual([0], arguments)
+ self.assertEqual(1, add_one_and_add_argument_to_list(0))
+ self.assertEqual([0], arguments)
+ self.assertEqual(1, add_one_and_add_argument_to_list(0))
+ self.assertEqual([0], arguments)
- def test_memoized_function_with_multiple_arguments(self):
- arguments = []
+ def test_memoized_function_with_multiple_arguments(self):
+ arguments = []
- @simple_memoizer.memoize
- def sum_arguments_and_add_arguments_to_list(n, m, o):
- arguments.append((n, m, o))
- return n + m + o
+ @simple_memoizer.memoize
+ def sum_arguments_and_add_arguments_to_list(n, m, o):
+ arguments.append((n, m, o))
+ return n + m + o
- self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2))
- self.assertEqual([(0, 1, 2)], arguments)
- self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2))
- self.assertEqual([(0, 1, 2)], arguments)
- self.assertEqual(3, sum_arguments_and_add_arguments_to_list(2, 1, 0))
- self.assertEqual([(0, 1, 2), (2, 1, 0)], arguments)
+ self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2))
+ self.assertEqual([(0, 1, 2)], arguments)
+ self.assertEqual(3, sum_arguments_and_add_arguments_to_list(0, 1, 2))
+ self.assertEqual([(0, 1, 2)], arguments)
+ self.assertEqual(3, sum_arguments_and_add_arguments_to_list(2, 1, 0))
+ self.assertEqual([(0, 1, 2), (2, 1, 0)], arguments)
- def test_memoized_function_with_no_arguments(self):
- arguments = []
+ def test_memoized_function_with_no_arguments(self):
+ arguments = []
- @simple_memoizer.memoize
- def return_one_and_add_empty_tuple_to_list():
- arguments.append(())
- return 1
+ @simple_memoizer.memoize
+ def return_one_and_add_empty_tuple_to_list():
+ arguments.append(())
+ return 1
- self.assertEqual(1, return_one_and_add_empty_tuple_to_list())
- self.assertEqual([()], arguments)
- self.assertEqual(1, return_one_and_add_empty_tuple_to_list())
- self.assertEqual([()], arguments)
+ self.assertEqual(1, return_one_and_add_empty_tuple_to_list())
+ self.assertEqual([()], arguments)
+ self.assertEqual(1, return_one_and_add_empty_tuple_to_list())
+ self.assertEqual([()], arguments)
-if __name__ == '__main__':
- unittest.main()
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/compiler/util/test_util.py b/compiler/util/test_util.py
index 02ac9a3..d54af37 100644
--- a/compiler/util/test_util.py
+++ b/compiler/util/test_util.py
@@ -18,89 +18,92 @@
def proto_is_superset(proto, expected_values, path=""):
- """Returns true if every value in expected_values is set in proto.
+ """Returns true if every value in expected_values is set in proto.
- This is intended to be used in assertTrue in a unit test, like so:
+ This is intended to be used in assertTrue in a unit test, like so:
- self.assertTrue(*proto_is_superset(proto, expected))
+ self.assertTrue(*proto_is_superset(proto, expected))
- Arguments:
- proto: The proto to check.
- expected_values: The reference proto.
- path: The path to the elements being compared. Clients can generally leave
- this at default.
+ Arguments:
+ proto: The proto to check.
+ expected_values: The reference proto.
+ path: The path to the elements being compared. Clients can generally leave
+ this at default.
- Returns:
- A tuple; the first element is True if the fields set in proto are a strict
- superset of the fields set in expected_values. The second element is an
- informational string specifying the path of a value found in expected_values
- but not in proto.
+ Returns:
+ A tuple; the first element is True if the fields set in proto are a strict
+ superset of the fields set in expected_values. The second element is an
+ informational string specifying the path of a value found in expected_values
+ but not in proto.
- Every atomic field that is set in expected_values must be set to the same
- value in proto; every message field set in expected_values must have a
- matching field in proto, such that proto_is_superset(proto.field,
- expected_values.field) is true.
+ Every atomic field that is set in expected_values must be set to the same
+ value in proto; every message field set in expected_values must have a
+ matching field in proto, such that proto_is_superset(proto.field,
+ expected_values.field) is true.
- For repeated fields in expected_values, each element in the expected_values
- proto must have a corresponding element at the same index in proto; proto
- may have additional elements.
- """
- if path:
- path += "."
- for spec, expected_value in ir_data_utils.get_set_fields(expected_values):
- name = spec.name
- field_path = "{}{}".format(path, name)
- value = getattr(proto, name)
- if spec.is_dataclass:
- if spec.is_sequence:
- if len(expected_value) > len(value):
- return False, "{}[{}] missing".format(field_path,
- len(getattr(proto, name)))
- for i in range(len(expected_value)):
- result = proto_is_superset(value[i], expected_value[i],
- "{}[{}]".format(field_path, i))
- if not result[0]:
- return result
- else:
- if (expected_values.HasField(name) and
- not proto.HasField(name)):
- return False, "{} missing".format(field_path)
- result = proto_is_superset(value, expected_value, field_path)
- if not result[0]:
- return result
- else:
- # Zero-length repeated fields and not-there repeated fields are "the
- # same."
- if (expected_value != value and
- (not spec.is_sequence or
- len(expected_value))):
- if spec.is_sequence:
- return False, "{} differs: found {}, expected {}".format(
- field_path, list(value), list(expected_value))
+ For repeated fields in expected_values, each element in the expected_values
+ proto must have a corresponding element at the same index in proto; proto
+ may have additional elements.
+ """
+ if path:
+ path += "."
+ for spec, expected_value in ir_data_utils.get_set_fields(expected_values):
+ name = spec.name
+ field_path = "{}{}".format(path, name)
+ value = getattr(proto, name)
+ if spec.is_dataclass:
+ if spec.is_sequence:
+ if len(expected_value) > len(value):
+ return False, "{}[{}] missing".format(
+ field_path, len(getattr(proto, name))
+ )
+ for i in range(len(expected_value)):
+ result = proto_is_superset(
+ value[i], expected_value[i], "{}[{}]".format(field_path, i)
+ )
+ if not result[0]:
+ return result
+ else:
+ if expected_values.HasField(name) and not proto.HasField(name):
+ return False, "{} missing".format(field_path)
+ result = proto_is_superset(value, expected_value, field_path)
+ if not result[0]:
+ return result
else:
- return False, "{} differs: found {}, expected {}".format(
- field_path, value, expected_value)
- return True, ""
+ # Zero-length repeated fields and not-there repeated fields are "the
+ # same."
+ if expected_value != value and (
+ not spec.is_sequence or len(expected_value)
+ ):
+ if spec.is_sequence:
+ return False, "{} differs: found {}, expected {}".format(
+ field_path, list(value), list(expected_value)
+ )
+ else:
+ return False, "{} differs: found {}, expected {}".format(
+ field_path, value, expected_value
+ )
+ return True, ""
def dict_file_reader(file_dict):
- """Returns a callable that retrieves entries from file_dict as files.
+ """Returns a callable that retrieves entries from file_dict as files.
- This can be used to call glue.parse_emboss_file with file text declared
- inline.
+ This can be used to call glue.parse_emboss_file with file text declared
+ inline.
- Arguments:
- file_dict: A dictionary from "file names" to "contents."
+ Arguments:
+ file_dict: A dictionary from "file names" to "contents."
- Returns:
- A callable that can be passed to glue.parse_emboss_file in place of the
- "read" builtin.
- """
+ Returns:
+ A callable that can be passed to glue.parse_emboss_file in place of the
+ "read" builtin.
+ """
- def read(file_name):
- try:
- return file_dict[file_name], None
- except KeyError:
- return None, ["File '{}' not found.".format(file_name)]
+ def read(file_name):
+ try:
+ return file_dict[file_name], None
+ except KeyError:
+ return None, ["File '{}' not found.".format(file_name)]
- return read
+ return read
diff --git a/compiler/util/test_util_test.py b/compiler/util/test_util_test.py
index e82f3c7..58e1ad6 100644
--- a/compiler/util/test_util_test.py
+++ b/compiler/util/test_util_test.py
@@ -22,149 +22,196 @@
class ProtoIsSupersetTest(unittest.TestCase):
- """Tests for test_util.proto_is_superset."""
+ """Tests for test_util.proto_is_superset."""
- def test_superset_extra_optional_field(self):
- self.assertEqual(
- (True, ""),
- test_util.proto_is_superset(
- ir_data.Structure(
- field=[ir_data.Field()],
- source_location=parser_types.parse_location("1:2-3:4")),
- ir_data.Structure(field=[ir_data.Field()])))
+ def test_superset_extra_optional_field(self):
+ self.assertEqual(
+ (True, ""),
+ test_util.proto_is_superset(
+ ir_data.Structure(
+ field=[ir_data.Field()],
+ source_location=parser_types.parse_location("1:2-3:4"),
+ ),
+ ir_data.Structure(field=[ir_data.Field()]),
+ ),
+ )
- def test_superset_extra_repeated_field(self):
- self.assertEqual(
- (True, ""),
- test_util.proto_is_superset(
- ir_data.Structure(
- field=[ir_data.Field(), ir_data.Field()],
- source_location=parser_types.parse_location("1:2-3:4")),
- ir_data.Structure(field=[ir_data.Field()])))
+ def test_superset_extra_repeated_field(self):
+ self.assertEqual(
+ (True, ""),
+ test_util.proto_is_superset(
+ ir_data.Structure(
+ field=[ir_data.Field(), ir_data.Field()],
+ source_location=parser_types.parse_location("1:2-3:4"),
+ ),
+ ir_data.Structure(field=[ir_data.Field()]),
+ ),
+ )
- def test_superset_missing_empty_repeated_field(self):
- self.assertEqual(
- (False, "field[0] missing"),
- test_util.proto_is_superset(
- ir_data.Structure(
- field=[],
- source_location=parser_types.parse_location("1:2-3:4")),
- ir_data.Structure(field=[ir_data.Field(), ir_data.Field()])))
+ def test_superset_missing_empty_repeated_field(self):
+ self.assertEqual(
+ (False, "field[0] missing"),
+ test_util.proto_is_superset(
+ ir_data.Structure(
+ field=[], source_location=parser_types.parse_location("1:2-3:4")
+ ),
+ ir_data.Structure(field=[ir_data.Field(), ir_data.Field()]),
+ ),
+ )
- def test_superset_missing_empty_optional_field(self):
- self.assertEqual((False, "source_location missing"),
- test_util.proto_is_superset(
- ir_data.Structure(field=[]),
- ir_data.Structure(source_location=ir_data.Location())))
+ def test_superset_missing_empty_optional_field(self):
+ self.assertEqual(
+ (False, "source_location missing"),
+ test_util.proto_is_superset(
+ ir_data.Structure(field=[]),
+ ir_data.Structure(source_location=ir_data.Location()),
+ ),
+ )
- def test_array_element_differs(self):
- self.assertEqual(
- (False,
- "field[0].source_location.start.line differs: found 1, expected 2"),
- test_util.proto_is_superset(
- ir_data.Structure(
- field=[ir_data.Field(source_location=parser_types.parse_location(
- "1:2-3:4"))]),
- ir_data.Structure(
- field=[ir_data.Field(source_location=parser_types.parse_location(
- "2:2-3:4"))])))
+ def test_array_element_differs(self):
+ self.assertEqual(
+ (False, "field[0].source_location.start.line differs: found 1, expected 2"),
+ test_util.proto_is_superset(
+ ir_data.Structure(
+ field=[
+ ir_data.Field(
+ source_location=parser_types.parse_location("1:2-3:4")
+ )
+ ]
+ ),
+ ir_data.Structure(
+ field=[
+ ir_data.Field(
+ source_location=parser_types.parse_location("2:2-3:4")
+ )
+ ]
+ ),
+ ),
+ )
- def test_equal(self):
- self.assertEqual(
- (True, ""),
- test_util.proto_is_superset(parser_types.parse_location("1:2-3:4"),
- parser_types.parse_location("1:2-3:4")))
+ def test_equal(self):
+ self.assertEqual(
+ (True, ""),
+ test_util.proto_is_superset(
+ parser_types.parse_location("1:2-3:4"),
+ parser_types.parse_location("1:2-3:4"),
+ ),
+ )
- def test_superset_missing_optional_field(self):
- self.assertEqual(
- (False, "source_location missing"),
- test_util.proto_is_superset(
- ir_data.Structure(field=[ir_data.Field()]),
- ir_data.Structure(
- field=[ir_data.Field()],
- source_location=parser_types.parse_location("1:2-3:4"))))
+ def test_superset_missing_optional_field(self):
+ self.assertEqual(
+ (False, "source_location missing"),
+ test_util.proto_is_superset(
+ ir_data.Structure(field=[ir_data.Field()]),
+ ir_data.Structure(
+ field=[ir_data.Field()],
+ source_location=parser_types.parse_location("1:2-3:4"),
+ ),
+ ),
+ )
- def test_optional_field_differs(self):
- self.assertEqual((False, "end.line differs: found 4, expected 3"),
- test_util.proto_is_superset(
- parser_types.parse_location("1:2-4:4"),
- parser_types.parse_location("1:2-3:4")))
+ def test_optional_field_differs(self):
+ self.assertEqual(
+ (False, "end.line differs: found 4, expected 3"),
+ test_util.proto_is_superset(
+ parser_types.parse_location("1:2-4:4"),
+ parser_types.parse_location("1:2-3:4"),
+ ),
+ )
- def test_non_message_repeated_field_equal(self):
- self.assertEqual((True, ""),
- test_util.proto_is_superset(
- ir_data.CanonicalName(object_path=[]),
- ir_data.CanonicalName(object_path=[])))
+ def test_non_message_repeated_field_equal(self):
+ self.assertEqual(
+ (True, ""),
+ test_util.proto_is_superset(
+ ir_data.CanonicalName(object_path=[]),
+ ir_data.CanonicalName(object_path=[]),
+ ),
+ )
- def test_non_message_repeated_field_missing_element(self):
- self.assertEqual(
- (False, "object_path differs: found {none!r}, expected {a!r}".format(
- none=[],
- a=[u"a"])),
- test_util.proto_is_superset(
- ir_data.CanonicalName(object_path=[]),
- ir_data.CanonicalName(object_path=[u"a"])))
+ def test_non_message_repeated_field_missing_element(self):
+ self.assertEqual(
+ (
+ False,
+ "object_path differs: found {none!r}, expected {a!r}".format(
+ none=[], a=["a"]
+ ),
+ ),
+ test_util.proto_is_superset(
+ ir_data.CanonicalName(object_path=[]),
+ ir_data.CanonicalName(object_path=["a"]),
+ ),
+ )
- def test_non_message_repeated_field_element_differs(self):
- self.assertEqual(
- (False, "object_path differs: found {aa!r}, expected {ab!r}".format(
- aa=[u"a", u"a"],
- ab=[u"a", u"b"])),
- test_util.proto_is_superset(
- ir_data.CanonicalName(object_path=[u"a", u"a"]),
- ir_data.CanonicalName(object_path=[u"a", u"b"])))
+ def test_non_message_repeated_field_element_differs(self):
+ self.assertEqual(
+ (
+ False,
+ "object_path differs: found {aa!r}, expected {ab!r}".format(
+ aa=["a", "a"], ab=["a", "b"]
+ ),
+ ),
+ test_util.proto_is_superset(
+ ir_data.CanonicalName(object_path=["a", "a"]),
+ ir_data.CanonicalName(object_path=["a", "b"]),
+ ),
+ )
- def test_non_message_repeated_field_extra_element(self):
- # For repeated fields of int/bool/str values, the entire list is treated as
- # an atomic unit, and should be equal.
- self.assertEqual(
- (False,
- "object_path differs: found {!r}, expected {!r}".format(
- [u"a", u"a"], [u"a"])),
- test_util.proto_is_superset(
- ir_data.CanonicalName(object_path=["a", "a"]),
- ir_data.CanonicalName(object_path=["a"])))
+ def test_non_message_repeated_field_extra_element(self):
+ # For repeated fields of int/bool/str values, the entire list is treated as
+ # an atomic unit, and should be equal.
+ self.assertEqual(
+ (
+ False,
+ "object_path differs: found {!r}, expected {!r}".format(
+ ["a", "a"], ["a"]
+ ),
+ ),
+ test_util.proto_is_superset(
+ ir_data.CanonicalName(object_path=["a", "a"]),
+ ir_data.CanonicalName(object_path=["a"]),
+ ),
+ )
- def test_non_message_repeated_field_no_expected_value(self):
- # When a repeated field is empty, it is the same as if it were entirely
- # missing -- there is no way to differentiate those two conditions.
- self.assertEqual((True, ""),
- test_util.proto_is_superset(
- ir_data.CanonicalName(object_path=["a", "a"]),
- ir_data.CanonicalName(object_path=[])))
+ def test_non_message_repeated_field_no_expected_value(self):
+ # When a repeated field is empty, it is the same as if it were entirely
+ # missing -- there is no way to differentiate those two conditions.
+ self.assertEqual(
+ (True, ""),
+ test_util.proto_is_superset(
+ ir_data.CanonicalName(object_path=["a", "a"]),
+ ir_data.CanonicalName(object_path=[]),
+ ),
+ )
class DictFileReaderTest(unittest.TestCase):
- """Tests for dict_file_reader."""
+ """Tests for dict_file_reader."""
- def test_empty_dict(self):
- reader = test_util.dict_file_reader({})
- self.assertEqual((None, ["File 'anything' not found."]), reader("anything"))
- self.assertEqual((None, ["File '' not found."]), reader(""))
+ def test_empty_dict(self):
+ reader = test_util.dict_file_reader({})
+ self.assertEqual((None, ["File 'anything' not found."]), reader("anything"))
+ self.assertEqual((None, ["File '' not found."]), reader(""))
- def test_one_element_dict(self):
- reader = test_util.dict_file_reader({"m": "abc"})
- self.assertEqual((None, ["File 'not_there' not found."]),
- reader("not_there"))
- self.assertEqual((None, ["File '' not found."]), reader(""))
- self.assertEqual(("abc", None), reader("m"))
+ def test_one_element_dict(self):
+ reader = test_util.dict_file_reader({"m": "abc"})
+ self.assertEqual((None, ["File 'not_there' not found."]), reader("not_there"))
+ self.assertEqual((None, ["File '' not found."]), reader(""))
+ self.assertEqual(("abc", None), reader("m"))
- def test_two_element_dict(self):
- reader = test_util.dict_file_reader({"m": "abc", "n": "def"})
- self.assertEqual((None, ["File 'not_there' not found."]),
- reader("not_there"))
- self.assertEqual((None, ["File '' not found."]), reader(""))
- self.assertEqual(("abc", None), reader("m"))
- self.assertEqual(("def", None), reader("n"))
+ def test_two_element_dict(self):
+ reader = test_util.dict_file_reader({"m": "abc", "n": "def"})
+ self.assertEqual((None, ["File 'not_there' not found."]), reader("not_there"))
+ self.assertEqual((None, ["File '' not found."]), reader(""))
+ self.assertEqual(("abc", None), reader("m"))
+ self.assertEqual(("def", None), reader("n"))
- def test_dict_with_empty_key(self):
- reader = test_util.dict_file_reader({"m": "abc", "": "def"})
- self.assertEqual((None, ["File 'not_there' not found."]),
- reader("not_there"))
- self.assertEqual((None, ["File 'None' not found."]), reader(None))
- self.assertEqual(("abc", None), reader("m"))
- self.assertEqual(("def", None), reader(""))
+ def test_dict_with_empty_key(self):
+ reader = test_util.dict_file_reader({"m": "abc", "": "def"})
+ self.assertEqual((None, ["File 'not_there' not found."]), reader("not_there"))
+ self.assertEqual((None, ["File 'None' not found."]), reader(None))
+ self.assertEqual(("abc", None), reader("m"))
+ self.assertEqual(("def", None), reader(""))
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/compiler/util/traverse_ir.py b/compiler/util/traverse_ir.py
index 93ffdc8..a6c3dda 100644
--- a/compiler/util/traverse_ir.py
+++ b/compiler/util/traverse_ir.py
@@ -23,373 +23,418 @@
class _FunctionCaller:
- """Provides a template for setting up a generic call to a function.
+ """Provides a template for setting up a generic call to a function.
- The function parameters are inspected at run-time to build up a set of valid
- and required arguments. When invoking the function unneccessary parameters
- will be trimmed out. If arguments are missing an assertion will be triggered.
+ The function parameters are inspected at run-time to build up a set of valid
+ and required arguments. When invoking the function unnecessary parameters
+ will be trimmed out. If arguments are missing an assertion will be triggered.
- This is currently limited to functions that have at least one positional
- parameter.
+ This is currently limited to functions that have at least one positional
+ parameter.
- Example usage:
- ```
- def func_1(a, b, c=2): pass
- def func_2(a, d): pass
- caller_1 = _FunctionCaller(func_1)
- caller_2 = _FunctionCaller(func_2)
- generic_params = {"b": 2, "c": 3, "d": 4}
+ Example usage:
+ ```
+ def func_1(a, b, c=2): pass
+ def func_2(a, d): pass
+ caller_1 = _FunctionCaller(func_1)
+ caller_2 = _FunctionCaller(func_2)
+ generic_params = {"b": 2, "c": 3, "d": 4}
- # Equivalent of: func_1(a, b=2, c=3)
- caller_1.invoke(a, generic_params)
+ # Equivalent of: func_1(a, b=2, c=3)
+ caller_1.invoke(a, generic_params)
- # Equivalent of: func_2(a, d=4)
- caller_2.invoke(a, generic_params)
- """
+ # Equivalent of: func_2(a, d=4)
+ caller_2.invoke(a, generic_params)
+ """
- def __init__(self, function):
- self.function = function
- self.needs_filtering = True
- self.valid_arg_names = set()
- self.required_arg_names = set()
+ def __init__(self, function):
+ self.function = function
+ self.needs_filtering = True
+ self.valid_arg_names = set()
+ self.required_arg_names = set()
- argspec = inspect.getfullargspec(function)
- if argspec.varkw:
- # If the function accepts a kwargs parameter, then it will accept all
- # arguments.
- # Note: this isn't technically true if one of the keyword arguments has the
- # same name as one of the positional arguments.
- self.needs_filtering = False
- else:
- # argspec.args is a list of all parameter names excluding keyword only
- # args. The first element is our required positional_arg and should be
- # ignored.
- args = argspec.args[1:]
- self.valid_arg_names.update(args)
+ argspec = inspect.getfullargspec(function)
+ if argspec.varkw:
+ # If the function accepts a kwargs parameter, then it will accept all
+ # arguments.
+ # Note: this isn't technically true if one of the keyword arguments has the
+ # same name as one of the positional arguments.
+ self.needs_filtering = False
+ else:
+ # argspec.args is a list of all parameter names excluding keyword only
+ # args. The first element is our required positional_arg and should be
+ # ignored.
+ args = argspec.args[1:]
+ self.valid_arg_names.update(args)
- # args.kwonlyargs gives us the list of keyword only args which are
- # also valid.
- self.valid_arg_names.update(argspec.kwonlyargs)
+ # args.kwonlyargs gives us the list of keyword only args which are
+ # also valid.
+ self.valid_arg_names.update(argspec.kwonlyargs)
- # Required args are positional arguments that don't have defaults.
- # Keyword only args are always optional and can be ignored. Args with
- # defaults are the last elements of the argsepec.args list and should
- # be ignored.
- if argspec.defaults:
- # Trim the arguments with defaults.
- args = args[: -len(argspec.defaults)]
- self.required_arg_names.update(args)
+ # Required args are positional arguments that don't have defaults.
+ # Keyword only args are always optional and can be ignored. Args with
+ # defaults are the last elements of the argspec.args list and should
+ # be ignored.
+ if argspec.defaults:
+ # Trim the arguments with defaults.
+ args = args[: -len(argspec.defaults)]
+ self.required_arg_names.update(args)
- def invoke(self, positional_arg, keyword_args):
- """Invokes the function with the given args."""
- if self.needs_filtering:
- # Trim to just recognized args.
- matched_args = {
- k: v for k, v in keyword_args.items() if k in self.valid_arg_names
- }
- # Check if any required args are missing.
- missing_args = self.required_arg_names.difference(matched_args.keys())
- assert not missing_args, (
- f"Attempting to call '{self.function.__name__}'; "
- f"missing {missing_args} (have {set(keyword_args.keys())})"
- )
- keyword_args = matched_args
+ def invoke(self, positional_arg, keyword_args):
+ """Invokes the function with the given args."""
+ if self.needs_filtering:
+ # Trim to just recognized args.
+ matched_args = {
+ k: v for k, v in keyword_args.items() if k in self.valid_arg_names
+ }
+ # Check if any required args are missing.
+ missing_args = self.required_arg_names.difference(matched_args.keys())
+ assert not missing_args, (
+ f"Attempting to call '{self.function.__name__}'; "
+ f"missing {missing_args} (have {set(keyword_args.keys())})"
+ )
+ keyword_args = matched_args
- return self.function(positional_arg, **keyword_args)
+ return self.function(positional_arg, **keyword_args)
@simple_memoizer.memoize
def _memoized_caller(function):
- default_lambda_name = (lambda: None).__name__
- assert (
- callable(function) and not function.__name__ == default_lambda_name
- ), "For performance reasons actions must be defined as static functions"
- return _FunctionCaller(function)
+ default_lambda_name = (lambda: None).__name__
+ assert (
+ callable(function) and not function.__name__ == default_lambda_name
+ ), "For performance reasons actions must be defined as static functions"
+ return _FunctionCaller(function)
def _call_with_optional_args(function, positional_arg, keyword_args):
- """Calls function with whatever keyword_args it will accept."""
- caller = _memoized_caller(function)
- return caller.invoke(positional_arg, keyword_args)
+ """Calls function with whatever keyword_args it will accept."""
+ caller = _memoized_caller(function)
+ return caller.invoke(positional_arg, keyword_args)
-def _fast_traverse_proto_top_down(proto, incidental_actions, pattern,
- skip_descendants_of, action, parameters):
- """Traverses an IR, calling `action` on some nodes."""
+def _fast_traverse_proto_top_down(
+ proto, incidental_actions, pattern, skip_descendants_of, action, parameters
+):
+ """Traverses an IR, calling `action` on some nodes."""
- # Parameters are scoped to the branch of the tree, so make a copy here, before
- # any action or incidental_action can update them.
- parameters = parameters.copy()
+ # Parameters are scoped to the branch of the tree, so make a copy here, before
+ # any action or incidental_action can update them.
+ parameters = parameters.copy()
- # If there is an incidental action for this node type, run it.
- if type(proto) in incidental_actions: # pylint: disable=unidiomatic-typecheck
- for incidental_action in incidental_actions[type(proto)]:
- parameters.update(_call_with_optional_args(
- incidental_action, proto, parameters) or {})
+ # If there is an incidental action for this node type, run it.
+ if type(proto) in incidental_actions: # pylint: disable=unidiomatic-typecheck
+ for incidental_action in incidental_actions[type(proto)]:
+ parameters.update(
+ _call_with_optional_args(incidental_action, proto, parameters) or {}
+ )
- # If we are at the end of pattern, check to see if we should call action.
- if len(pattern) == 1:
- new_pattern = pattern
- if pattern[0] == type(proto):
- parameters.update(
- _call_with_optional_args(action, proto, parameters) or {})
- else:
- # Otherwise, if this node's type matches the head of pattern, recurse with
- # the tail of the pattern.
- if pattern[0] == type(proto):
- new_pattern = pattern[1:]
+ # If we are at the end of pattern, check to see if we should call action.
+ if len(pattern) == 1:
+ new_pattern = pattern
+ if pattern[0] == type(proto):
+ parameters.update(_call_with_optional_args(action, proto, parameters) or {})
else:
- new_pattern = pattern
+ # Otherwise, if this node's type matches the head of pattern, recurse with
+ # the tail of the pattern.
+ if pattern[0] == type(proto):
+ new_pattern = pattern[1:]
+ else:
+ new_pattern = pattern
- # If the current node's type is one of the types whose branch should be
- # skipped, then bail. This has to happen after `action` is called, because
- # clients rely on being able to, e.g., get a callback for the "root"
- # Expression without getting callbacks for every sub-Expression.
- # pylint: disable=unidiomatic-typecheck
- if type(proto) in skip_descendants_of:
- return
+ # If the current node's type is one of the types whose branch should be
+ # skipped, then bail. This has to happen after `action` is called, because
+ # clients rely on being able to, e.g., get a callback for the "root"
+ # Expression without getting callbacks for every sub-Expression.
+ # pylint: disable=unidiomatic-typecheck
+ if type(proto) in skip_descendants_of:
+ return
- # Otherwise, recurse. _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET tells us, given
- # the current node's type and the current target type, which fields to check.
- singular_fields, repeated_fields = _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET[
- type(proto), new_pattern[0]]
- for member_name in singular_fields:
- if proto.HasField(member_name):
- _fast_traverse_proto_top_down(getattr(proto, member_name),
- incidental_actions, new_pattern,
- skip_descendants_of, action, parameters)
- for member_name in repeated_fields:
- for array_element in getattr(proto, member_name) or []:
- _fast_traverse_proto_top_down(array_element, incidental_actions,
- new_pattern, skip_descendants_of, action,
- parameters)
+ # Otherwise, recurse. _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET tells us, given
+ # the current node's type and the current target type, which fields to check.
+ singular_fields, repeated_fields = _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET[
+ type(proto), new_pattern[0]
+ ]
+ for member_name in singular_fields:
+ if proto.HasField(member_name):
+ _fast_traverse_proto_top_down(
+ getattr(proto, member_name),
+ incidental_actions,
+ new_pattern,
+ skip_descendants_of,
+ action,
+ parameters,
+ )
+ for member_name in repeated_fields:
+ for array_element in getattr(proto, member_name) or []:
+ _fast_traverse_proto_top_down(
+ array_element,
+ incidental_actions,
+ new_pattern,
+ skip_descendants_of,
+ action,
+ parameters,
+ )
def _fields_to_scan_by_current_and_target():
- """Generates _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET."""
- # In order to avoid spending a *lot* of time just walking the IR, this
- # function sets up a dict that allows `_fast_traverse_proto_top_down()` to
- # skip traversing large portions of the IR, depending on what node types it is
- # targeting.
- #
- # Without this branch culling scheme, the Emboss front end (at time of
- # writing) spends roughly 70% (19s out of 31s) of its time just walking the
- # IR. With branch culling, that goes down to 6% (0.7s out of 12.2s).
+ """Generates _FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET."""
+ # In order to avoid spending a *lot* of time just walking the IR, this
+ # function sets up a dict that allows `_fast_traverse_proto_top_down()` to
+ # skip traversing large portions of the IR, depending on what node types it is
+ # targeting.
+ #
+ # Without this branch culling scheme, the Emboss front end (at time of
+ # writing) spends roughly 70% (19s out of 31s) of its time just walking the
+ # IR. With branch culling, that goes down to 6% (0.7s out of 12.2s).
- # type_to_fields is a map of types to maps of field names to field types.
- # That is, type_to_fields[ir_data.Module]["type"] == ir_data.AddressableUnit.
- type_to_fields = {}
+ # type_to_fields is a map of types to maps of field names to field types.
+ # That is, type_to_fields[ir_data.Module]["type"] == ir_data.AddressableUnit.
+ type_to_fields = {}
- # Later, we need to know which fields are singular and which are repeated,
- # because the access methods are not uniform. This maps (type, field_name)
- # tuples to descriptor labels: type_fields_to_cardinality[ir_data.Module,
- # "type"] == ir_data.Repeated.
- type_fields_to_cardinality = {}
+ # Later, we need to know which fields are singular and which are repeated,
+ # because the access methods are not uniform. This maps (type, field_name)
+ # tuples to descriptor labels: type_fields_to_cardinality[ir_data.Module,
+ # "type"] == ir_data.Repeated.
+ type_fields_to_cardinality = {}
- # Fill out the above maps by recursively walking the IR type tree, starting
- # from the root.
- types_to_check = [ir_data.EmbossIr]
- while types_to_check:
- type_to_check = types_to_check.pop()
- if type_to_check in type_to_fields:
- continue
- fields = {}
- for field_name, field_type in ir_data_utils.field_specs(type_to_check).items():
- if field_type.is_dataclass:
- fields[field_name] = field_type.data_type
- types_to_check.append(field_type.data_type)
- type_fields_to_cardinality[type_to_check, field_name] = (
- field_type.container)
- type_to_fields[type_to_check] = fields
+ # Fill out the above maps by recursively walking the IR type tree, starting
+ # from the root.
+ types_to_check = [ir_data.EmbossIr]
+ while types_to_check:
+ type_to_check = types_to_check.pop()
+ if type_to_check in type_to_fields:
+ continue
+ fields = {}
+ for field_name, field_type in ir_data_utils.field_specs(type_to_check).items():
+ if field_type.is_dataclass:
+ fields[field_name] = field_type.data_type
+ types_to_check.append(field_type.data_type)
+ type_fields_to_cardinality[type_to_check, field_name] = (
+ field_type.container
+ )
+ type_to_fields[type_to_check] = fields
- # type_to_descendant_types is a map of all types that can be reached from a
- # particular type. After the setup, type_to_descendant_types[ir_data.EmbossIr]
- # == set(<all types>) and type_to_descendant_types[ir_data.Reference] ==
- # {ir_data.CanonicalName, ir_data.Word, ir_data.Location} and
- # type_to_descendant_types[ir_data.Word] == set().
- #
- # The while loop basically ors in the known descendants of each known
- # descendant of each type until the dict stops changing, which is a bit
- # brute-force, but in practice only iterates a few times.
- type_to_descendant_types = {}
- for parent_type, field_map in type_to_fields.items():
- type_to_descendant_types[parent_type] = set(field_map.values())
- previous_map = {}
- while type_to_descendant_types != previous_map:
- # In order to check the previous iteration against the current iteration, it
- # is necessary to make a two-level copy. Otherwise, the updates to the
- # values will also update previous_map's values, which causes the loop to
- # exit prematurely.
- previous_map = {k: set(v) for k, v in type_to_descendant_types.items()}
- for ancestor_type, descendents in previous_map.items():
- for descendent in descendents:
- type_to_descendant_types[ancestor_type] |= previous_map[descendent]
+ # type_to_descendant_types is a map of all types that can be reached from a
+ # particular type. After the setup, type_to_descendant_types[ir_data.EmbossIr]
+ # == set(<all types>) and type_to_descendant_types[ir_data.Reference] ==
+ # {ir_data.CanonicalName, ir_data.Word, ir_data.Location} and
+ # type_to_descendant_types[ir_data.Word] == set().
+ #
+ # The while loop basically ors in the known descendants of each known
+ # descendant of each type until the dict stops changing, which is a bit
+ # brute-force, but in practice only iterates a few times.
+ type_to_descendant_types = {}
+ for parent_type, field_map in type_to_fields.items():
+ type_to_descendant_types[parent_type] = set(field_map.values())
+ previous_map = {}
+ while type_to_descendant_types != previous_map:
+ # In order to check the previous iteration against the current iteration, it
+ # is necessary to make a two-level copy. Otherwise, the updates to the
+ # values will also update previous_map's values, which causes the loop to
+ # exit prematurely.
+ previous_map = {k: set(v) for k, v in type_to_descendant_types.items()}
+ for ancestor_type, descendents in previous_map.items():
+ for descendent in descendents:
+ type_to_descendant_types[ancestor_type] |= previous_map[descendent]
- # Finally, we have all of the information we need to make the map we really
- # want: given a current node type and a target node type, which fields should
- # be checked? (This implicitly skips fields that *can't* contain the target
- # type.)
- fields_to_scan_by_current_and_target = {}
- for current_node_type in type_to_fields:
- for target_node_type in type_to_fields:
- singular_fields_to_scan = []
- repeated_fields_to_scan = []
- for field_name, field_type in type_to_fields[current_node_type].items():
- # If the target node type cannot contain another instance of itself, it
- # is still necessary to scan fields that have the actual target type.
- if (target_node_type == field_type or
- target_node_type in type_to_descendant_types[field_type]):
- # Singular and repeated fields go to different lists, so that they can
- # be handled separately.
- if (type_fields_to_cardinality[current_node_type, field_name] is not
- ir_data_fields.FieldContainer.LIST):
- singular_fields_to_scan.append(field_name)
- else:
- repeated_fields_to_scan.append(field_name)
- fields_to_scan_by_current_and_target[
- current_node_type, target_node_type] = (
- singular_fields_to_scan, repeated_fields_to_scan)
- return fields_to_scan_by_current_and_target
+ # Finally, we have all of the information we need to make the map we really
+ # want: given a current node type and a target node type, which fields should
+ # be checked? (This implicitly skips fields that *can't* contain the target
+ # type.)
+ fields_to_scan_by_current_and_target = {}
+ for current_node_type in type_to_fields:
+ for target_node_type in type_to_fields:
+ singular_fields_to_scan = []
+ repeated_fields_to_scan = []
+ for field_name, field_type in type_to_fields[current_node_type].items():
+ # If the target node type cannot contain another instance of itself, it
+ # is still necessary to scan fields that have the actual target type.
+ if (
+ target_node_type == field_type
+ or target_node_type in type_to_descendant_types[field_type]
+ ):
+ # Singular and repeated fields go to different lists, so that they can
+ # be handled separately.
+ if (
+ type_fields_to_cardinality[current_node_type, field_name]
+ is not ir_data_fields.FieldContainer.LIST
+ ):
+ singular_fields_to_scan.append(field_name)
+ else:
+ repeated_fields_to_scan.append(field_name)
+ fields_to_scan_by_current_and_target[
+ current_node_type, target_node_type
+ ] = (singular_fields_to_scan, repeated_fields_to_scan)
+ return fields_to_scan_by_current_and_target
_FIELDS_TO_SCAN_BY_CURRENT_AND_TARGET = _fields_to_scan_by_current_and_target()
+
def _emboss_ir_action(ir):
- return {"ir": ir}
+ return {"ir": ir}
+
def _module_action(m):
- return {"source_file_name": m.source_file_name}
+ return {"source_file_name": m.source_file_name}
+
def _type_definition_action(t):
- return {"type_definition": t}
+ return {"type_definition": t}
+
def _field_action(f):
- return {"field": f}
-
-def fast_traverse_ir_top_down(ir, pattern, action, incidental_actions=None,
- skip_descendants_of=(), parameters=None):
- """Traverses an IR from the top down, executing the given actions.
-
- `fast_traverse_ir_top_down` walks the given IR in preorder traversal,
- specifically looking for nodes whose path from the root of the tree matches
- `pattern`. For every node which matches `pattern`, `action` will be called.
-
- `pattern` is just a list of node types. For example, to execute `print` on
- every `ir_data.Word` in the IR:
-
- fast_traverse_ir_top_down(ir, [ir_data.Word], print)
-
- If more than one type is specified, then each one must be found inside the
- previous. For example, to print only the Words inside of import statements:
-
- fast_traverse_ir_top_down(ir, [ir_data.Import, ir_data.Word], print)
-
- The optional arguments provide additional control.
-
- `skip_descendants_of` is a list of types that should be treated as if they are
- leaf nodes when they are encountered. That is, traversal will skip any
- nodes with any ancestor node whose type is in `skip_descendants_of`. For
- example, to `do_something` only on outermost `Expression`s:
-
- fast_traverse_ir_top_down(ir, [ir_data.Expression], do_something,
- skip_descendants_of={ir_data.Expression})
-
- `parameters` specifies a dictionary of initial parameters which can be passed
- as arguments to `action` and `incidental_actions`. Note that the parameters
- can be overridden for parts of the tree by `action` and `incidental_actions`.
- Parameters can be used to set an object which may be updated by `action`, such
- as a list of errors generated by some check in `action`:
-
- def check_structure(structure, errors):
- if structure_is_bad(structure):
- errors.append(error_for_structure(structure))
-
- errors = []
- fast_traverse_ir_top_down(ir, [ir_data.Structure], check_structure,
- parameters={"errors": errors})
- if errors:
- print("Errors: {}".format(errors))
- sys.exit(1)
-
- `incidental_actions` is a map from node types to functions (or tuples of
- functions or lists of functions) which should be called on those nodes.
- Because `fast_traverse_ir_top_down` may skip branches that can't contain
- `pattern`, functions in `incidental_actions` should generally not have any
- side effects: instead, they may return a dictionary, which will be used to
- override `parameters` for any children of the node they were called on. For
- example:
-
- def do_something(expression, field_name=None):
- if field_name:
- print("Found {} inside {}".format(expression, field_name))
- else:
- print("Found {} not in any field".format(expression))
-
- fast_traverse_ir_top_down(
- ir, [ir_data.Expression], do_something,
- incidental_actions={ir_data.Field: lambda f: {"field_name": f.name}})
-
- (The `action` may also return a dict in the same way.)
-
- A few `incidental_actions` are built into `fast_traverse_ir_top_down`, so
- that certain parameters are contextually available with well-known names:
-
- ir: The complete IR (the root ir_data.EmbossIr node).
- source_file_name: The file name from which the current node was sourced.
- type_definition: The most-immediate ancestor type definition.
- field: The field containing the current node, if any.
-
- Arguments:
- ir: An ir_data.Ir object to walk.
- pattern: A list of node types to match.
- action: A callable, which will be called on nodes matching `pattern`.
- incidental_actions: A dict of node types to callables, which can be used to
- set new parameters for `action` for part of the IR tree.
- skip_descendants_of: A list of types whose children should be skipped when
- traversing `ir`.
- parameters: A list of top-level parameters.
-
- Returns:
- None
- """
- all_incidental_actions = {
- ir_data.EmbossIr: [_emboss_ir_action],
- ir_data.Module: [_module_action],
- ir_data.TypeDefinition: [_type_definition_action],
- ir_data.Field: [_field_action],
- }
- if incidental_actions:
- for key, incidental_action in incidental_actions.items():
- if not isinstance(incidental_action, (list, tuple)):
- incidental_action = [incidental_action]
- all_incidental_actions.setdefault(key, []).extend(incidental_action)
- _fast_traverse_proto_top_down(ir, all_incidental_actions, pattern,
- skip_descendants_of, action, parameters or {})
+ return {"field": f}
-def fast_traverse_node_top_down(node, pattern, action, incidental_actions=None,
- skip_descendants_of=(), parameters=None):
- """Traverse a subtree of an IR, executing the given actions.
+def fast_traverse_ir_top_down(
+ ir,
+ pattern,
+ action,
+ incidental_actions=None,
+ skip_descendants_of=(),
+ parameters=None,
+):
+ """Traverses an IR from the top down, executing the given actions.
- fast_traverse_node_top_down is like fast_traverse_ir_top_down, except that:
+ `fast_traverse_ir_top_down` walks the given IR in preorder traversal,
+ specifically looking for nodes whose path from the root of the tree matches
+ `pattern`. For every node which matches `pattern`, `action` will be called.
- It may be called on a subtree, instead of the top of the IR.
+ `pattern` is just a list of node types. For example, to execute `print` on
+ every `ir_data.Word` in the IR:
- It does not have any built-in incidental actions.
+ fast_traverse_ir_top_down(ir, [ir_data.Word], print)
- Arguments:
- node: An ir_data.Ir object to walk.
- pattern: A list of node types to match.
- action: A callable, which will be called on nodes matching `pattern`.
- incidental_actions: A dict of node types to callables, which can be used to
- set new parameters for `action` for part of the IR tree.
- skip_descendants_of: A list of types whose children should be skipped when
- traversing `node`.
- parameters: A list of top-level parameters.
+ If more than one type is specified, then each one must be found inside the
+ previous. For example, to print only the Words inside of import statements:
- Returns:
- None
- """
- _fast_traverse_proto_top_down(node, incidental_actions or {}, pattern,
- skip_descendants_of or {}, action,
- parameters or {})
+ fast_traverse_ir_top_down(ir, [ir_data.Import, ir_data.Word], print)
+
+ The optional arguments provide additional control.
+
+ `skip_descendants_of` is a list of types that should be treated as if they are
+ leaf nodes when they are encountered. That is, traversal will skip any
+ nodes with any ancestor node whose type is in `skip_descendants_of`. For
+ example, to `do_something` only on outermost `Expression`s:
+
+ fast_traverse_ir_top_down(ir, [ir_data.Expression], do_something,
+ skip_descendants_of={ir_data.Expression})
+
+ `parameters` specifies a dictionary of initial parameters which can be passed
+ as arguments to `action` and `incidental_actions`. Note that the parameters
+ can be overridden for parts of the tree by `action` and `incidental_actions`.
+ Parameters can be used to set an object which may be updated by `action`, such
+ as a list of errors generated by some check in `action`:
+
+ def check_structure(structure, errors):
+ if structure_is_bad(structure):
+ errors.append(error_for_structure(structure))
+
+ errors = []
+ fast_traverse_ir_top_down(ir, [ir_data.Structure], check_structure,
+ parameters={"errors": errors})
+ if errors:
+ print("Errors: {}".format(errors))
+ sys.exit(1)
+
+ `incidental_actions` is a map from node types to functions (or tuples of
+ functions or lists of functions) which should be called on those nodes.
+ Because `fast_traverse_ir_top_down` may skip branches that can't contain
+ `pattern`, functions in `incidental_actions` should generally not have any
+ side effects: instead, they may return a dictionary, which will be used to
+ override `parameters` for any children of the node they were called on. For
+ example:
+
+ def do_something(expression, field_name=None):
+ if field_name:
+ print("Found {} inside {}".format(expression, field_name))
+ else:
+ print("Found {} not in any field".format(expression))
+
+ fast_traverse_ir_top_down(
+ ir, [ir_data.Expression], do_something,
+ incidental_actions={ir_data.Field: lambda f: {"field_name": f.name}})
+
+ (The `action` may also return a dict in the same way.)
+
+ A few `incidental_actions` are built into `fast_traverse_ir_top_down`, so
+ that certain parameters are contextually available with well-known names:
+
+ ir: The complete IR (the root ir_data.EmbossIr node).
+ source_file_name: The file name from which the current node was sourced.
+ type_definition: The most-immediate ancestor type definition.
+ field: The field containing the current node, if any.
+
+ Arguments:
+ ir: An ir_data.Ir object to walk.
+ pattern: A list of node types to match.
+ action: A callable, which will be called on nodes matching `pattern`.
+ incidental_actions: A dict of node types to callables, which can be used to
+ set new parameters for `action` for part of the IR tree.
+ skip_descendants_of: A list of types whose children should be skipped when
+ traversing `ir`.
+ parameters: A list of top-level parameters.
+
+ Returns:
+ None
+ """
+ all_incidental_actions = {
+ ir_data.EmbossIr: [_emboss_ir_action],
+ ir_data.Module: [_module_action],
+ ir_data.TypeDefinition: [_type_definition_action],
+ ir_data.Field: [_field_action],
+ }
+ if incidental_actions:
+ for key, incidental_action in incidental_actions.items():
+ if not isinstance(incidental_action, (list, tuple)):
+ incidental_action = [incidental_action]
+ all_incidental_actions.setdefault(key, []).extend(incidental_action)
+ _fast_traverse_proto_top_down(
+ ir,
+ all_incidental_actions,
+ pattern,
+ skip_descendants_of,
+ action,
+ parameters or {},
+ )
+
+
+def fast_traverse_node_top_down(
+ node,
+ pattern,
+ action,
+ incidental_actions=None,
+ skip_descendants_of=(),
+ parameters=None,
+):
+ """Traverse a subtree of an IR, executing the given actions.
+
+ fast_traverse_node_top_down is like fast_traverse_ir_top_down, except that:
+
+ It may be called on a subtree, instead of the top of the IR.
+
+ It does not have any built-in incidental actions.
+
+ Arguments:
+ node: An ir_data.Ir object to walk.
+ pattern: A list of node types to match.
+ action: A callable, which will be called on nodes matching `pattern`.
+ incidental_actions: A dict of node types to callables, which can be used to
+ set new parameters for `action` for part of the IR tree.
+ skip_descendants_of: A list of types whose children should be skipped when
+ traversing `node`.
+ parameters: A list of top-level parameters.
+
+ Returns:
+ None
+ """
+ _fast_traverse_proto_top_down(
+ node,
+ incidental_actions or {},
+ pattern,
+ skip_descendants_of or {},
+ action,
+ parameters or {},
+ )
diff --git a/compiler/util/traverse_ir_test.py b/compiler/util/traverse_ir_test.py
index ff54d63..504eb88 100644
--- a/compiler/util/traverse_ir_test.py
+++ b/compiler/util/traverse_ir_test.py
@@ -22,7 +22,9 @@
from compiler.util import ir_data_utils
from compiler.util import traverse_ir
-_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, """{
+_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(
+ ir_data.EmbossIr,
+ """{
"module": [
{
"type": [
@@ -175,165 +177,252 @@
"source_file_name": ""
}
]
-}""")
+}""",
+)
def _count_entries(sequence):
- counts = collections.Counter()
- for entry in sequence:
- counts[entry] += 1
- return counts
+ counts = collections.Counter()
+ for entry in sequence:
+ counts[entry] += 1
+ return counts
def _record_constant(constant, constant_list):
- constant_list.append(int(constant.value))
+ constant_list.append(int(constant.value))
def _record_field_name_and_constant(constant, constant_list, field):
- constant_list.append((field.name.name.text, int(constant.value)))
+ constant_list.append((field.name.name.text, int(constant.value)))
def _record_file_name_and_constant(constant, constant_list, source_file_name):
- constant_list.append((source_file_name, int(constant.value)))
+ constant_list.append((source_file_name, int(constant.value)))
-def _record_location_parameter_and_constant(constant, constant_list,
- location=None):
- constant_list.append((location, int(constant.value)))
+def _record_location_parameter_and_constant(constant, constant_list, location=None):
+ constant_list.append((location, int(constant.value)))
def _record_kind_and_constant(constant, constant_list, type_definition):
- if type_definition.HasField("enumeration"):
- constant_list.append(("enumeration", int(constant.value)))
- elif type_definition.HasField("structure"):
- constant_list.append(("structure", int(constant.value)))
- elif type_definition.HasField("external"):
- constant_list.append(("external", int(constant.value)))
- else:
- assert False, "Shouldn't be here."
+ if type_definition.HasField("enumeration"):
+ constant_list.append(("enumeration", int(constant.value)))
+ elif type_definition.HasField("structure"):
+ constant_list.append(("structure", int(constant.value)))
+ elif type_definition.HasField("external"):
+ constant_list.append(("external", int(constant.value)))
+ else:
+ assert False, "Shouldn't be here."
class TraverseIrTest(unittest.TestCase):
- def test_filter_on_type(self):
- constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
- parameters={"constant_list": constants})
- self.assertEqual(
- _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]),
- _count_entries(constants))
+ def test_filter_on_type(self):
+ constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.NumericConstant],
+ _record_constant,
+ parameters={"constant_list": constants},
+ )
+ self.assertEqual(
+ _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320, 1, 1, 1, 64]),
+ _count_entries(constants),
+ )
- def test_filter_on_type_in_type(self):
- constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR,
- [ir_data.Function, ir_data.Expression, ir_data.NumericConstant],
- _record_constant,
- parameters={"constant_list": constants})
- self.assertEqual([1, 1], constants)
+ def test_filter_on_type_in_type(self):
+ constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.Function, ir_data.Expression, ir_data.NumericConstant],
+ _record_constant,
+ parameters={"constant_list": constants},
+ )
+ self.assertEqual([1, 1], constants)
- def test_filter_on_type_star_type(self):
- struct_constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.Structure, ir_data.NumericConstant],
- _record_constant,
- parameters={"constant_list": struct_constants})
- self.assertEqual(_count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]),
- _count_entries(struct_constants))
- enum_constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.Enum, ir_data.NumericConstant], _record_constant,
- parameters={"constant_list": enum_constants})
- self.assertEqual(_count_entries([1, 1, 1]), _count_entries(enum_constants))
+ def test_filter_on_type_star_type(self):
+ struct_constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.Structure, ir_data.NumericConstant],
+ _record_constant,
+ parameters={"constant_list": struct_constants},
+ )
+ self.assertEqual(
+ _count_entries([0, 8, 8, 8, 16, 24, 32, 16, 32, 320]),
+ _count_entries(struct_constants),
+ )
+ enum_constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.Enum, ir_data.NumericConstant],
+ _record_constant,
+ parameters={"constant_list": enum_constants},
+ )
+ self.assertEqual(_count_entries([1, 1, 1]), _count_entries(enum_constants))
- def test_filter_on_not_type(self):
- notstruct_constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.NumericConstant], _record_constant,
- skip_descendants_of=(ir_data.Structure,),
- parameters={"constant_list": notstruct_constants})
- self.assertEqual(_count_entries([1, 1, 1, 64]),
- _count_entries(notstruct_constants))
+ def test_filter_on_not_type(self):
+ notstruct_constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.NumericConstant],
+ _record_constant,
+ skip_descendants_of=(ir_data.Structure,),
+ parameters={"constant_list": notstruct_constants},
+ )
+ self.assertEqual(
+ _count_entries([1, 1, 1, 64]), _count_entries(notstruct_constants)
+ )
- def test_field_is_populated(self):
- constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.Field, ir_data.NumericConstant],
- _record_field_name_and_constant,
- parameters={"constant_list": constants})
- self.assertEqual(_count_entries([
- ("field1", 0), ("field1", 8), ("field2", 8), ("field2", 8),
- ("field2", 16), ("bar_field1", 24), ("bar_field1", 32),
- ("bar_field2", 16), ("bar_field2", 32), ("bar_field2", 320)
- ]), _count_entries(constants))
+ def test_field_is_populated(self):
+ constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.Field, ir_data.NumericConstant],
+ _record_field_name_and_constant,
+ parameters={"constant_list": constants},
+ )
+ self.assertEqual(
+ _count_entries(
+ [
+ ("field1", 0),
+ ("field1", 8),
+ ("field2", 8),
+ ("field2", 8),
+ ("field2", 16),
+ ("bar_field1", 24),
+ ("bar_field1", 32),
+ ("bar_field2", 16),
+ ("bar_field2", 32),
+ ("bar_field2", 320),
+ ]
+ ),
+ _count_entries(constants),
+ )
- def test_file_name_is_populated(self):
- constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.NumericConstant], _record_file_name_and_constant,
- parameters={"constant_list": constants})
- self.assertEqual(_count_entries([
- ("t.emb", 0), ("t.emb", 8), ("t.emb", 8), ("t.emb", 8), ("t.emb", 16),
- ("t.emb", 24), ("t.emb", 32), ("t.emb", 16), ("t.emb", 32),
- ("t.emb", 320), ("t.emb", 1), ("t.emb", 1), ("t.emb", 1), ("", 64)
- ]), _count_entries(constants))
+ def test_file_name_is_populated(self):
+ constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.NumericConstant],
+ _record_file_name_and_constant,
+ parameters={"constant_list": constants},
+ )
+ self.assertEqual(
+ _count_entries(
+ [
+ ("t.emb", 0),
+ ("t.emb", 8),
+ ("t.emb", 8),
+ ("t.emb", 8),
+ ("t.emb", 16),
+ ("t.emb", 24),
+ ("t.emb", 32),
+ ("t.emb", 16),
+ ("t.emb", 32),
+ ("t.emb", 320),
+ ("t.emb", 1),
+ ("t.emb", 1),
+ ("t.emb", 1),
+ ("", 64),
+ ]
+ ),
+ _count_entries(constants),
+ )
- def test_type_definition_is_populated(self):
- constants = []
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.NumericConstant], _record_kind_and_constant,
- parameters={"constant_list": constants})
- self.assertEqual(_count_entries([
- ("structure", 0), ("structure", 8), ("structure", 8), ("structure", 8),
- ("structure", 16), ("structure", 24), ("structure", 32),
- ("structure", 16), ("structure", 32), ("structure", 320),
- ("enumeration", 1), ("enumeration", 1), ("enumeration", 1),
- ("external", 64)
- ]), _count_entries(constants))
+ def test_type_definition_is_populated(self):
+ constants = []
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.NumericConstant],
+ _record_kind_and_constant,
+ parameters={"constant_list": constants},
+ )
+ self.assertEqual(
+ _count_entries(
+ [
+ ("structure", 0),
+ ("structure", 8),
+ ("structure", 8),
+ ("structure", 8),
+ ("structure", 16),
+ ("structure", 24),
+ ("structure", 32),
+ ("structure", 16),
+ ("structure", 32),
+ ("structure", 320),
+ ("enumeration", 1),
+ ("enumeration", 1),
+ ("enumeration", 1),
+ ("external", 64),
+ ]
+ ),
+ _count_entries(constants),
+ )
- def test_keyword_args_dict_in_action(self):
- call_counts = {"populated": 0, "not": 0}
+ def test_keyword_args_dict_in_action(self):
+ call_counts = {"populated": 0, "not": 0}
- def check_field_is_populated(node, **kwargs):
- del node # Unused.
- self.assertTrue(kwargs["field"])
- call_counts["populated"] += 1
+ def check_field_is_populated(node, **kwargs):
+ del node # Unused.
+ self.assertTrue(kwargs["field"])
+ call_counts["populated"] += 1
- def check_field_is_not_populated(node, **kwargs):
- del node # Unused.
- self.assertFalse("field" in kwargs)
- call_counts["not"] += 1
+ def check_field_is_not_populated(node, **kwargs):
+ del node # Unused.
+ self.assertFalse("field" in kwargs)
+ call_counts["not"] += 1
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.Field, ir_data.Type], check_field_is_populated)
- self.assertEqual(7, call_counts["populated"])
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR, [ir_data.Field, ir_data.Type], check_field_is_populated
+ )
+ self.assertEqual(7, call_counts["populated"])
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue],
- check_field_is_not_populated)
- self.assertEqual(2, call_counts["not"])
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR, [ir_data.Enum, ir_data.EnumValue], check_field_is_not_populated
+ )
+ self.assertEqual(2, call_counts["not"])
- def test_pass_only_to_sub_nodes(self):
- constants = []
+ def test_pass_only_to_sub_nodes(self):
+ constants = []
- def pass_location_down(field):
- return {
- "location": (int(field.location.start.constant.value),
- int(field.location.size.constant.value))
- }
+ def pass_location_down(field):
+ return {
+ "location": (
+ int(field.location.start.constant.value),
+ int(field.location.size.constant.value),
+ )
+ }
- traverse_ir.fast_traverse_ir_top_down(
- _EXAMPLE_IR, [ir_data.NumericConstant],
- _record_location_parameter_and_constant,
- incidental_actions={ir_data.Field: pass_location_down},
- parameters={"constant_list": constants, "location": None})
- self.assertEqual(_count_entries([
- ((0, 8), 0), ((0, 8), 8), ((8, 16), 8), ((8, 16), 8), ((8, 16), 16),
- ((24, 32), 24), ((24, 32), 32), ((32, 320), 16), ((32, 320), 32),
- ((32, 320), 320), (None, 1), (None, 1), (None, 1), (None, 64)
- ]), _count_entries(constants))
+ traverse_ir.fast_traverse_ir_top_down(
+ _EXAMPLE_IR,
+ [ir_data.NumericConstant],
+ _record_location_parameter_and_constant,
+ incidental_actions={ir_data.Field: pass_location_down},
+ parameters={"constant_list": constants, "location": None},
+ )
+ self.assertEqual(
+ _count_entries(
+ [
+ ((0, 8), 0),
+ ((0, 8), 8),
+ ((8, 16), 8),
+ ((8, 16), 8),
+ ((8, 16), 16),
+ ((24, 32), 24),
+ ((24, 32), 32),
+ ((32, 320), 16),
+ ((32, 320), 32),
+ ((32, 320), 320),
+ (None, 1),
+ (None, 1),
+ (None, 1),
+ (None, 64),
+ ]
+ ),
+ _count_entries(constants),
+ )
if __name__ == "__main__":
- unittest.main()
+ unittest.main()
diff --git a/doc/__init__.py b/doc/__init__.py
index 2c31d84..086a24e 100644
--- a/doc/__init__.py
+++ b/doc/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/runtime/cpp/generators/all_known.py b/runtime/cpp/generators/all_known.py
index 5efa03f..c3c1108 100644
--- a/runtime/cpp/generators/all_known.py
+++ b/runtime/cpp/generators/all_known.py
@@ -22,7 +22,8 @@
OVERLOADS = 64
# Copyright header in the generated code complies with Google policies.
-print("""// Copyright 2020 Google LLC
+print(
+ """// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -37,18 +38,27 @@
// limitations under the License.
// GENERATED CODE. DO NOT EDIT. REGENERATE WITH
-// runtime/cpp/generators/all_known.py""")
+// runtime/cpp/generators/all_known.py"""
+)
for i in range(1, OVERLOADS + 1):
- print("""
+ print(
+ """
template <{}>
inline constexpr bool AllKnown({}) {{
return {};
}}""".format(
- ", ".join(["typename T{}".format(n) for n in range(i)] +
- (["typename... RestT"] if i == OVERLOADS else [])),
- ", ".join(["T{} v{}".format(n, n) for n in range(i)] +
- (["RestT... rest"] if i == OVERLOADS else [])),
- " && ".join(["v{}.Known()".format(n) for n in range(i)] +
- (["AllKnown(rest...)"] if i == OVERLOADS else []))))
-
+ ", ".join(
+ ["typename T{}".format(n) for n in range(i)]
+ + (["typename... RestT"] if i == OVERLOADS else [])
+ ),
+ ", ".join(
+ ["T{} v{}".format(n, n) for n in range(i)]
+ + (["RestT... rest"] if i == OVERLOADS else [])
+ ),
+ " && ".join(
+ ["v{}.Known()".format(n) for n in range(i)]
+ + (["AllKnown(rest...)"] if i == OVERLOADS else [])
+ ),
+ )
+ )
diff --git a/runtime/cpp/generators/maximum_operation_do.py b/runtime/cpp/generators/maximum_operation_do.py
index 9609873..b0087a3 100644
--- a/runtime/cpp/generators/maximum_operation_do.py
+++ b/runtime/cpp/generators/maximum_operation_do.py
@@ -24,7 +24,8 @@
OVERLOADS = 64
# Copyright header in the generated code complies with Google policies.
-print("""// Copyright 2020 Google LLC
+print(
+ """// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -39,17 +40,21 @@
// limitations under the License.
// GENERATED CODE. DO NOT EDIT. REGENERATE WITH
-// runtime/cpp/generators/maximum_operation_do.py""")
+// runtime/cpp/generators/maximum_operation_do.py"""
+)
for i in range(5, OVERLOADS + 1):
- print("""
+ print(
+ """
template <typename T>
static inline constexpr T Do({0}) {{
return Do(Do({1}), Do({2}));
}}""".strip().format(
- ", ".join(["T v{}".format(n) for n in range(i)]),
- ", ".join(["v{}".format(n) for n in range(i // 2)]),
- ", ".join(["v{}".format(n) for n in range(i // 2, i)])))
+ ", ".join(["T v{}".format(n) for n in range(i)]),
+ ", ".join(["v{}".format(n) for n in range(i // 2)]),
+ ", ".join(["v{}".format(n) for n in range(i // 2, i)]),
+ )
+ )
# The "more than OVERLOADS arguments" overload uses a variadic template to
# handle the remaining arguments, even though all arguments should have the
@@ -59,11 +64,14 @@
#
# This also uses one explicit argument, rest0, to ensure that it does not get
# confused with the last non-variadic overload.
-print("""
+print(
+ """
template <typename T, typename... RestT>
static inline constexpr T Do({0}, T rest0, RestT... rest) {{
return Do(Do({1}), Do(rest0, rest...));
}}""".format(
- ", ".join(["T v{}".format(n) for n in range(OVERLOADS)]),
- ", ".join(["v{}".format(n) for n in range(OVERLOADS)]),
- OVERLOADS))
+ ", ".join(["T v{}".format(n) for n in range(OVERLOADS)]),
+ ", ".join(["v{}".format(n) for n in range(OVERLOADS)]),
+ OVERLOADS,
+ )
+)
diff --git a/testdata/__init__.py b/testdata/__init__.py
index 2c31d84..086a24e 100644
--- a/testdata/__init__.py
+++ b/testdata/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/testdata/format/__init__.py b/testdata/format/__init__.py
index 2c31d84..086a24e 100644
--- a/testdata/format/__init__.py
+++ b/testdata/format/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
diff --git a/testdata/golden/__init__.py b/testdata/golden/__init__.py
index 2c31d84..086a24e 100644
--- a/testdata/golden/__init__.py
+++ b/testdata/golden/__init__.py
@@ -11,4 +11,3 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-