Merge pull request #149 from EricRahm/use_dataclass
Convert `ir_data` to a `dataclass`
diff --git a/compiler/front_end/constraints_test.py b/compiler/front_end/constraints_test.py
index b214107..eac1232 100644
--- a/compiler/front_end/constraints_test.py
+++ b/compiler/front_end/constraints_test.py
@@ -19,6 +19,7 @@
from compiler.front_end import constraints
from compiler.front_end import glue
from compiler.util import error
+from compiler.util import ir_data_utils
from compiler.util import ir_util
from compiler.util import test_util
@@ -38,7 +39,13 @@
def test_error_on_missing_inner_array_size(self):
ir = _make_ir_from_emb("struct Foo:\n"
" 0 [+1] UInt:8[][1] one_byte\n")
- error_array = ir.module[0].type[0].structure.field[0].type.array_type
+ # There is a latent issue here where the source location reported in this
+ # error is using a default value of 0:0. An issue is filed at
+ # https://github.com/google/emboss/issues/153 for further investigation.
+ # In the meantime we use `ir_data_utils.reader` to mimic this legacy
+ # behavior.
+ error_array = ir_data_utils.reader(
+ ir.module[0].type[0].structure.field[0].type.array_type)
self.assertEqual([[
error.error(
"m.emb",
diff --git a/compiler/front_end/module_ir_test.py b/compiler/front_end/module_ir_test.py
index 57d5f4c..4eac8f1 100644
--- a/compiler/front_end/module_ir_test.py
+++ b/compiler/front_end/module_ir_test.py
@@ -24,6 +24,7 @@
from compiler.front_end import parser
from compiler.front_end import tokenizer
from compiler.util import ir_data
+from compiler.util import ir_data_fields
from compiler.util import ir_data_utils
from compiler.util import test_util
@@ -4089,29 +4090,28 @@
child_end = None
# Only check the source_location value if this proto message actually has a
# source_location field.
- if "source_location" in proto.raw_fields:
+ if proto.HasField("source_location"):
errors.extend(_check_source_location(proto.source_location,
path + "source_location",
min_start, max_end))
child_start = proto.source_location.start
child_end = proto.source_location.end
- for name, spec in proto.field_specs.items():
+ for name, spec in ir_data_fields.field_specs(proto).items():
if name == "source_location":
continue
if not proto.HasField(name):
continue
field_path = "{}{}".format(path, name)
- if isinstance(spec, ir_pb2.Repeated):
- if issubclass(spec.type, ir_pb2.Message):
+ if spec.is_dataclass:
+ if spec.is_sequence:
index = 0
for i in getattr(proto, name):
item_path = "{}[{}]".format(field_path, index)
index += 1
errors.extend(
_check_all_source_locations(i, item_path, child_start, child_end))
- else:
- if issubclass(spec.type, ir_data.Message):
+ else:
errors.extend(_check_all_source_locations(getattr(proto, name),
field_path, child_start,
child_end))
diff --git a/compiler/front_end/symbol_resolver.py b/compiler/front_end/symbol_resolver.py
index 5990f13..9328e8f 100644
--- a/compiler/front_end/symbol_resolver.py
+++ b/compiler/front_end/symbol_resolver.py
@@ -293,7 +293,7 @@
name = reference.source_name[0].text
for scope in visible_scopes:
scoped_table = table[scope.module_file]
- for path_element in scope.object_path:
+ for path_element in scope.object_path or []:
scoped_table = scoped_table[path_element]
if (name in scoped_table and
(scope == current_scope or
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index 42b3cff..73be294 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -29,11 +29,11 @@
return
if hasattr(proto, "source_location"):
ir_data_utils.builder(proto).source_location.is_synthetic = True
- for name, value in proto.raw_fields.items():
- if name != "source_location":
- if isinstance(value, ir_data.TypedScopedList):
- for i in range(len(value)):
- _mark_as_synthetic(value[i])
+ for spec, value in ir_data_utils.get_set_fields(proto):
+ if spec.name != "source_location" and spec.is_dataclass:
+ if spec.is_sequence:
+ for i in value:
+ _mark_as_synthetic(i)
else:
_mark_as_synthetic(value)
diff --git a/compiler/util/BUILD b/compiler/util/BUILD
index 5946dcb..ab53344 100644
--- a/compiler/util/BUILD
+++ b/compiler/util/BUILD
@@ -25,8 +25,22 @@
name = "ir_data",
srcs = [
"ir_data.py",
+ "ir_data_fields.py",
"ir_data_utils.py",
],
+ deps = [],
+)
+
+py_test(
+ name = "ir_data_fields_test",
+ srcs = ["ir_data_fields_test.py"],
+ deps = [":ir_data"],
+)
+
+py_test(
+ name = "ir_data_utils_test",
+ srcs = ["ir_data_utils_test.py"],
+ deps = [":expression_parser", ":ir_data"],
)
py_library(
diff --git a/compiler/util/error.py b/compiler/util/error.py
index c85d8fc..e408c71 100644
--- a/compiler/util/error.py
+++ b/compiler/util/error.py
@@ -36,6 +36,7 @@
]
"""
+from compiler.util import ir_data_utils
from compiler.util import parser_types
# Error levels; represented by the strings that will be included in messages.
@@ -65,20 +66,26 @@
BOLD = "\033[0;1m"
RESET = "\033[0m"
+def _copy(location):
+ location = ir_data_utils.copy(location)
+ if not location:
+    location = parser_types.make_location((0, 0), (0, 0))
+ return location
+
def error(source_file, location, message):
"""Returns an object representing an error message."""
- return _Message(source_file, location, ERROR, message)
+ return _Message(source_file, _copy(location), ERROR, message)
def warn(source_file, location, message):
"""Returns an object representing a warning."""
- return _Message(source_file, location, WARNING, message)
+ return _Message(source_file, _copy(location), WARNING, message)
def note(source_file, location, message):
"""Returns and object representing an informational note."""
- return _Message(source_file, location, NOTE, message)
+ return _Message(source_file, _copy(location), NOTE, message)
class _Message(object):
diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py
index 6ee0b47..624b09a 100644
--- a/compiler/util/ir_data.py
+++ b/compiler/util/ir_data.py
@@ -14,411 +14,170 @@
"""Intermediate representation (IR) for Emboss.
-This was originally a Google Protocol Buffer file, but as of 2019 it turns
-out that a) the public Google Python Protocol Buffer implementation is
-extremely slow, and b) all the ways of getting Bazel+Python+Protocol Buffers
-to play nice are hacky and fragile.
-
-Thus, this file, which presents a similar-enough interface that the rest of
-Emboss can use it with minimal changes.
-
-Protobufs have a really, really strange, un-Pythonic interface, with tricky
-implicit semantics -- mostly around magically instantiating protos when you
-assign to some deeply-nested field. I (bolms@) would *strongly* prefer to
-have a more explicit interface, but don't (currently) have time to refactor
-everything that touches the IR (i.e., the entire compiler).
+This is limited to purely data and type annotations.
"""
+import dataclasses
import enum
-import json
import sys
+from typing import ClassVar, Optional
+
+from compiler.util import ir_data_fields
-if sys.version_info[0] == 2:
- _Text = unicode
- _text_types = (unicode, str)
- _Int = long
- _int_types = (int, long)
-else:
- _Text = str
- _text_types = (str,)
- _Int = int
- _int_types = (int,)
+@dataclasses.dataclass
+class Message:
+ """Base class for IR data objects.
+ Historically protocol buffers were used for serializing this data which has
+ led to some legacy naming conventions and references. In particular this
+ class is named `Message` in the sense of a protocol buffer message,
+ indicating that it is intended to just be data that is used by other higher
+ level services.
-_BASIC_TYPES = _text_types + _int_types + (bool,)
-
-
-class Optional(object):
- """Property implementation for "optional"-like fields."""
-
- def __init__(self, type_, oneof=None, decode_names=None):
- """Creates a Proto "optional"-like data member.
-
- Args:
- type_: The type of the field; e.g., _Int or _Text.
- oneof: If set, the name of the proto-like "oneof" that this field is a
- member of. Within a structure, at most one field of a particular
- "oneof" may be set at a time; setting a member of the "oneof" will
- clear any other member that might be set.
- decode_names: An optional callable that takes a str and returns a
- value of type type_; allows strs to be used to set "enums" using
- their symbolic names.
- """
- self._type = type_
- self._oneof = oneof
- self._decode_names = decode_names
-
- def __get__(self, obj, type_=None):
- result = obj.raw_fields.get(self.name, None)
- if result is not None:
- return result
- if self.type in _BASIC_TYPES:
- return self._type()
- result = self._type()
-
- def on_write():
- self._set_value(obj, result)
-
- result.set_on_write(on_write)
- return result
-
- def __set__(self, obj, value):
- if issubclass(self._type, _BASIC_TYPES):
- self.set(obj, value)
- else:
- raise AttributeError("Cannot set {} (type {}) for type {}".format(
- value, value.__class__, self._type))
-
- def _set_value(self, obj, value):
- if self._oneof is not None:
- current = obj.oneofs.get(self._oneof)
- if current in obj.raw_fields:
- del obj.raw_fields[current]
- obj.oneofs[self._oneof] = self.name
- obj.raw_fields[self.name] = value
- obj.on_write()
-
- def set(self, obj, value):
- """Sets the given value to the property."""
-
- if value is None:
- return
- if isinstance(value, dict):
- self._set_value(obj, self._type(**value))
- elif isinstance(value, _Text) and self._decode_names:
- self._set_value(obj, self._type(self._decode_names(value)))
- elif isinstance(value, _Text) and issubclass(self._type, enum.Enum):
- self._set_value(obj, getattr(self._type, value))
- elif (not isinstance(value, self._type) and
- not (self._type == _Int and isinstance(value, _int_types)) and
- not (self._type == _Text and isinstance(value, _text_types)) and
- not (issubclass(self._type, enum.Enum) and isinstance(value, _int_types))):
- raise AttributeError("Cannot set {} (type {}) for type {}".format(
- value, value.__class__, self._type))
- elif issubclass(self._type, Message):
- self._set_value(obj, self._type(**value.raw_fields))
- else:
- self._set_value(obj, self._type(value))
-
- def resolve_type(self):
- if isinstance(self._type, type(lambda: None)):
- self._type = self._type()
-
- @property
- def type(self):
- return self._type
-
-
-class TypedScopedList(object):
- """A list with typechecking that notifies its parent when written to.
-
- TypedScopedList implements roughly the same semantics as the value of a
- Protobuf repeated field. In particular, it checks that any values added
- to the list are of the correct type, and it calls the on_write callable
- when a value is added to the list, in order to implement the Protobuf
- "autovivification" semantics.
+  There are some other legacy idioms left over from the protocol buffer-based
+  definition, such as support for "oneof" and optional fields.
"""
- def __init__(self, type_, on_write=lambda: None):
- self._type = type_
- self._list = []
- self._on_write = on_write
+ IR_DATACLASS: ClassVar[object] = object()
+ field_specs: ClassVar[ir_data_fields.FilteredIrFieldSpecs]
- def __iter__(self):
- return iter(self._list)
+ def __post_init__(self):
+ """Called by dataclass subclasses after init.
- def __delitem__(self, key):
- del self._list[key]
-
- def __getitem__(self, key):
- return self._list[key]
-
- def extend(self, values):
- """list-like extend()."""
-
- for value in values:
- if isinstance(value, dict):
- self._list.append(self._type(**value))
- elif (not isinstance(value, self._type) and
- not (self._type == _Int and isinstance(value, _int_types)) and
- not (self._type == _Text and isinstance(value, _text_types))):
- raise TypeError(
- "Needed {}, got {} ({!r})".format(
- self._type, value.__class__, value))
+ Post-processes any lists passed in to use our custom list type.
+ """
+ # Convert any lists passed in to CopyValuesList
+ for spec in self.field_specs.sequence_field_specs:
+ cur_val = getattr(self, spec.name)
+ if isinstance(cur_val, ir_data_fields.TemporaryCopyValuesList):
+ copy_val = cur_val.temp_list
else:
- if self._type in _BASIC_TYPES:
- self._list.append(self._type(value))
- else:
- self._list.append(self._type(**value.raw_fields))
- self._on_write()
+ copy_val = ir_data_fields.CopyValuesList(spec.data_type)
+ if cur_val:
+ copy_val.shallow_copy(cur_val)
+ setattr(self, spec.name, copy_val)
- def __repr__(self):
- return repr(self._list)
-
- def __len__(self):
- return len(self._list)
-
- def __eq__(self, other):
- return ((self.__class__ == other.__class__ and
- self._list == other._list) or # pylint:disable=protected-access
- (isinstance(other, list) and self._list == other))
-
- def __ne__(self, other):
- return not (self == other) # pylint:disable=superfluous-parens
-
-
-class Repeated(object):
- """A proto-"repeated"-like property."""
-
- def __init__(self, type_):
- self._type = type_
-
- def __get__(self, obj, type_=None):
- return obj.raw_fields[self.name]
-
- def __set__(self, obj, value):
- raise AttributeError("Cannot set {}".format(self.name))
-
- def set(self, obj, values):
- typed_list = obj.raw_fields[self.name]
- if not isinstance(values, (list, TypedScopedList)):
- raise TypeError("Cannot initialize repeated field {} from {}".format(
- self.name, values.__class__))
- del typed_list[:]
- typed_list.extend(values)
-
- def resolve_type(self):
- if isinstance(self._type, type(lambda: None)):
- self._type = self._type()
-
- @property
- def type(self):
- return self._type
-
-
-_deferred_specs = []
-
-
-def message(cls):
- # TODO(bolms): move this into __init_subclass__ after dropping Python 2
- # support.
- _deferred_specs.append(cls)
- return cls
-
-
-class Message(object):
- """Base class for proto "message"-like objects."""
-
- def __init__(self, **field_values):
- self.oneofs = {}
- self._on_write = lambda: None
- self._initialize_raw_fields_from(field_values)
-
- def _initialize_raw_fields_from(self, field_values):
- self.raw_fields = {}
- for name, type_ in self.repeated_fields.items():
- self.raw_fields[name] = TypedScopedList(type_, self.on_write)
- for k, v in field_values.items():
- spec = self.field_specs.get(k)
- if spec is None:
- raise AttributeError("No field {} on {}.".format(
- k, self.__class__.__name__))
- spec.set(self, v)
-
- @classmethod
- def from_json(cls, text):
- as_dict = json.loads(text)
- return cls(**as_dict)
-
- def on_write(self):
- self._on_write()
- self._on_write = lambda: None
-
- def set_on_write(self, on_write):
- self._on_write = on_write
-
- def __eq__(self, other):
- return (self.__class__ == other.__class__ and
- self.raw_fields == other.raw_fields)
-
- # Non-PEP8 name to mimic the Google Protobuf interface.
- def CopyFrom(self, other): # pylint:disable=invalid-name
- if self.__class__ != other.__class__:
- raise TypeError("{} cannot CopyFrom {}".format(
- self.__class__.__name__, other.__class__.__name__))
- self._initialize_raw_fields_from(other.raw_fields)
- self.on_write()
+ # This hook adds a 15% overhead to end-to-end code generation in some cases
+ # so we guard it in a `__debug__` block. Users can opt-out of this check by
+ # running python with the `-O` flag, ie: `python3 -O ./embossc`.
+ if __debug__:
+ def __setattr__(self, name: str, value) -> None:
+ """Debug-only hook that adds basic type checking for ir_data fields."""
+ if spec := self.field_specs.all_field_specs.get(name):
+ if not (
+ # Check if it's the expected type
+ isinstance(value, spec.data_type) or
+ # Oneof fields are a special case
+ spec.is_oneof or
+ # Optional fields can be set to None
+ (spec.container is ir_data_fields.FieldContainer.OPTIONAL and
+ value is None) or
+ # Sequences can be a few variants of lists
+ (spec.is_sequence and
+ isinstance(value, (
+ list, ir_data_fields.TemporaryCopyValuesList,
+ ir_data_fields.CopyValuesList))) or
+ # An enum value can be an int
+ (spec.is_enum and isinstance(value, int))):
+ raise AttributeError(
+ f"Cannot set {value} (type {value.__class__}) for type"
+ "{spec.data_type}")
+ object.__setattr__(self, name, value)
# Non-PEP8 name to mimic the Google Protobuf interface.
def HasField(self, name): # pylint:disable=invalid-name
- return name in self.raw_fields
+ """Indicates if this class has the given field defined and it is set."""
+ return getattr(self, name, None) is not None
# Non-PEP8 name to mimic the Google Protobuf interface.
def WhichOneof(self, oneof_name): # pylint:disable=invalid-name
- return self.oneofs.get(oneof_name)
+ """Indicates which field has been set for the oneof value.
- def to_dict(self):
- """Converts the message to a dict."""
-
- result = {}
- for k, v in self.raw_fields.items():
- if isinstance(v, _BASIC_TYPES):
- result[k] = v
- elif isinstance(v, TypedScopedList):
- if v:
- # For compatibility with the proto world, empty lists are just
- # elided.
- result[k] = [
- item if isinstance(item, _BASIC_TYPES) else item.to_dict()
- for item in v
- ]
- else:
- result[k] = v.to_dict()
- return result
-
- def __repr__(self):
- return self.to_json(separators=(",", ":"), sort_keys=True)
-
- def to_json(self, *args, **kwargs):
- return json.dumps(self.to_dict(), *args, **kwargs)
-
- def __str__(self):
- return _Text(self.to_dict())
-
-
-def _initialize_deferred_specs():
- """Calls any lambdas in specs, to facilitate late binding.
-
- When two Message subclasses are mutually recursive, the standard way of
- referencing one of the classes will not work, because its name is not
- yet defined. E.g.:
-
- class A(Message):
- b = Optional(B)
-
- class B(Message):
- a = Optional(A)
-
- In this case, Python complains when trying to construct the class A,
- because it cannot resolve B.
-
- To accommodate this, Optional and Repeated will accept callables instead of
- types, like:
-
- class A(Message):
- b = Optional(lambda: B)
-
- class B(Message):
- a = Optional(A)
-
- Once all of the message classes have been defined, it is safe to go back and
- resolve all of the names by calling all of the lambdas that were used in place
- of types. This function just iterates through the message types, and asks
- their Optional and Repeated properties to call the lambdas that were used in
- place of types.
- """
-
- for cls in _deferred_specs:
- field_specs = {}
- repeated_fields = {}
- for k, v in cls.__dict__.items():
- if k.startswith("_"):
- continue
- if isinstance(v, (Optional, Repeated)):
- v.name = k
- v.resolve_type()
- field_specs[k] = v
- if isinstance(v, Repeated):
- repeated_fields[k] = v.type
- cls.field_specs = field_specs
- cls.repeated_fields = repeated_fields
+ Returns None if no field has been set.
+ """
+ for field_name, oneof in self.field_specs.oneof_mappings:
+ if oneof == oneof_name and self.HasField(field_name):
+ return field_name
+ return None
################################################################################
-# From here to (nearly) the end of the file are actual structure definitions.
+# From here to the end of the file are actual structure definitions.
-@message
+@dataclasses.dataclass
class Position(Message):
"""A zero-width position within a source file."""
- line = Optional(int) # Line (starts from 1).
- column = Optional(int) # Column (starts from 1).
+
+ line: int = 0
+ """Line (starts from 1)."""
+ column: int = 0
+ """Column (starts from 1)."""
-@message
+@dataclasses.dataclass
class Location(Message):
"""A half-open start:end range within a source file."""
- start = Optional(Position) # Beginning of the range.
- end = Optional(Position) # One column past the end of the range.
- # True if this Location is outside of the parent object's Location.
- is_disjoint_from_parent = Optional(bool)
+ start: Optional[Position] = None
+ """Beginning of the range"""
+ end: Optional[Position] = None
+ """One column past the end of the range."""
- # True if this Location's parent was synthesized, and does not directly
- # appear in the source file. The Emboss front end uses this field to cull
- # irrelevant error messages.
- is_synthetic = Optional(bool)
+ is_disjoint_from_parent: Optional[bool] = None
+ """True if this Location is outside of the parent object's Location."""
+
+ is_synthetic: Optional[bool] = None
+ """True if this Location's parent was synthesized, and does not directly
+ appear in the source file.
+
+ The Emboss front end uses this field to cull
+ irrelevant error messages.
+ """
-@message
+@dataclasses.dataclass
class Word(Message):
"""IR for a bare word in the source file.
This is used in NameDefinitions and References.
"""
- text = Optional(_Text)
- source_location = Optional(Location)
+ text: Optional[str] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class String(Message):
"""IR for a string in the source file."""
- text = Optional(_Text)
- source_location = Optional(Location)
+
+ text: Optional[str] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Documentation(Message):
- text = Optional(_Text)
- source_location = Optional(Location)
+ text: Optional[str] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class BooleanConstant(Message):
"""IR for a boolean constant."""
- value = Optional(bool)
- source_location = Optional(Location)
+
+ value: Optional[bool] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Empty(Message):
"""Placeholder message for automatic element counts for arrays."""
- source_location = Optional(Location)
+
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class NumericConstant(Message):
"""IR for any numeric constant."""
@@ -427,43 +186,59 @@
#
# TODO(bolms): switch back to int, and just use strings during
# serialization, now that we're free of proto.
- value = Optional(_Text)
- source_location = Optional(Location)
+ value: Optional[str] = None
+ source_location: Optional[Location] = None
class FunctionMapping(int, enum.Enum):
"""Enum of supported function types"""
+
UNKNOWN = 0
- ADDITION = 1 # +
- SUBTRACTION = 2 # -
- MULTIPLICATION = 3 # *
- EQUALITY = 4 # ==
- INEQUALITY = 5 # !=
- AND = 6 # &&
- OR = 7 # ||
- LESS = 8 # <
- LESS_OR_EQUAL = 9 # <=
- GREATER = 10 # >
- GREATER_OR_EQUAL = 11 # >=
- CHOICE = 12 # ?:
- MAXIMUM = 13 # $max()
- PRESENCE = 14 # $present()
- UPPER_BOUND = 15 # $upper_bound()
- LOWER_BOUND = 16 # $lower_bound()
+ ADDITION = 1
+ """`+`"""
+ SUBTRACTION = 2
+ """`-`"""
+ MULTIPLICATION = 3
+ """`*`"""
+ EQUALITY = 4
+ """`==`"""
+ INEQUALITY = 5
+ """`!=`"""
+ AND = 6
+ """`&&`"""
+ OR = 7
+ """`||`"""
+ LESS = 8
+ """`<`"""
+ LESS_OR_EQUAL = 9
+ """`<=`"""
+ GREATER = 10
+ """`>`"""
+ GREATER_OR_EQUAL = 11
+ """`>=`"""
+ CHOICE = 12
+ """`?:`"""
+ MAXIMUM = 13
+ """`$max()`"""
+ PRESENCE = 14
+ """`$present()`"""
+ UPPER_BOUND = 15
+ """`$upper_bound()`"""
+ LOWER_BOUND = 16
+ """`$lower_bound()`"""
-@message
+@dataclasses.dataclass
class Function(Message):
"""IR for a single function (+, -, *, ==, $max, etc.) in an expression."""
- # pylint:disable=undefined-variable
- function = Optional(FunctionMapping)
- args = Repeated(lambda: Expression)
- function_name = Optional(Word)
- source_location = Optional(Location)
+ function: Optional[FunctionMapping] = None
+ args: list["Expression"] = ir_data_fields.list_field(lambda: Expression)
+ function_name: Optional[Word] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class CanonicalName(Message):
"""CanonicalName is the unique, absolute name for some object.
@@ -472,32 +247,37 @@
Foo"), and in references to objects (a field of type "Foo").
"""
- # The module_file is the Module.source_file_name of the Module in which this
- # object's definition appears. Note that the Prelude always has a
- # Module.source_file_name of "", and thus references to Prelude names will
- # have module_file == "".
- module_file = Optional(_Text)
+ module_file: str = ir_data_fields.str_field()
+ """The module_file is the Module.source_file_name of the Module in which this
+ object's definition appears.
- # The object_path is the canonical path to the object definition within its
- # module file. For example, the field "bar" would have an object path of
- # ["Foo", "bar"]:
- #
- # struct Foo:
- # 0:3 UInt bar
- #
- #
- # The enumerated name "BOB" would have an object path of ["Baz", "Qux",
- # "BOB"]:
- #
- # struct Baz:
- # 0:3 Qux qux
- #
- # enum Qux:
- # BOB = 0
- object_path = Repeated(_Text)
+ Note that the Prelude always has a Module.source_file_name of "", and thus
+ references to Prelude names will have module_file == "".
+ """
+
+ object_path: list[str] = ir_data_fields.list_field(str)
+ """The object_path is the canonical path to the object definition within its
+ module file.
+
+ For example, the field "bar" would have an object path of
+ ["Foo", "bar"]:
+
+ struct Foo:
+ 0:3 UInt bar
-@message
+ The enumerated name "BOB" would have an object path of ["Baz", "Qux",
+ "BOB"]:
+
+ struct Baz:
+ 0:3 Qux qux
+
+ enum Qux:
+ BOB = 0
+ """
+
+
+@dataclasses.dataclass
class NameDefinition(Message):
"""NameDefinition is IR for the name of an object, within the object.
@@ -505,26 +285,31 @@
name.
"""
- # The name, as directly generated from the source text. name.text will
- # match the last element of canonical_name.object_path. Note that in some
- # cases, the exact string in name.text may not appear in the source text.
- name = Optional(Word)
+ name: Optional[Word] = None
+ """The name, as directly generated from the source text.
- # The CanonicalName that will appear in References. This field is
- # technically redundant: canonical_name.module_file should always match the
- # source_file_name of the enclosing Module, and canonical_name.object_path
- # should always match the names of parent nodes.
- canonical_name = Optional(CanonicalName)
+ name.text will match the last element of canonical_name.object_path. Note
+ that in some cases, the exact string in name.text may not appear in the
+ source text.
+ """
- # If true, indicates that this is an automatically-generated name, which
- # should not be visible outside of its immediate namespace.
- is_anonymous = Optional(bool)
+ canonical_name: Optional[CanonicalName] = None
+ """The CanonicalName that will appear in References.
+ This field is technically redundant: canonical_name.module_file should always
+ match the source_file_name of the enclosing Module, and
+ canonical_name.object_path should always match the names of parent nodes.
+ """
- # The location of this NameDefinition in source code.
- source_location = Optional(Location)
+ is_anonymous: Optional[bool] = None
+ """If true, indicates that this is an automatically-generated name, which
+ should not be visible outside of its immediate namespace.
+ """
+
+ source_location: Optional[Location] = None
+ """The location of this NameDefinition in source code."""
-@message
+@dataclasses.dataclass
class Reference(Message):
"""A Reference holds the canonical name of something defined elsewhere.
@@ -542,30 +327,37 @@
what appears in the .emb.
"""
- # The canonical name of the object being referred to. This name should be
- # used to find the object in the IR.
- canonical_name = Optional(CanonicalName)
+ canonical_name: Optional[CanonicalName] = None
+ """The canonical name of the object being referred to.
- # The source_name is the name the user entered in the source file; it could
- # be either relative or absolute, and may be an alias (and thus not match
- # any part of the canonical_name). Back ends should use canonical_name for
- # name lookup, and reserve source_name for error messages.
- source_name = Repeated(Word)
+ This name should be used to find the object in the IR.
+ """
- # If true, then symbol resolution should only look at local names when
- # resolving source_name. This is used so that the names of inline types
- # aren't "ambiguous" if there happens to be another type with the same name
- # at a parent scope.
- is_local_name = Optional(bool)
+ source_name: list[Word] = ir_data_fields.list_field(Word)
+ """The source_name is the name the user entered in the source file.
+
+ The source_name could be either relative or absolute, and may be an alias
+ (and thus not match any part of the canonical_name). Back ends should use
+ canonical_name for name lookup, and reserve source_name for error messages.
+ """
+
+ is_local_name: Optional[bool] = None
+ """If true, then symbol resolution should only look at local names when
+ resolving source_name.
+
+ This is used so that the names of inline types aren't "ambiguous" if there
+ happens to be another type with the same name at a parent scope.
+ """
# TODO(bolms): Allow absolute paths starting with ".".
- # Note that this is the source_location of the *Reference*, not of the
- # object to which it refers.
- source_location = Optional(Location)
+ source_location: Optional[Location] = None
+ """Note that this is the source_location of the *Reference*, not of the
+ object to which it refers.
+ """
-@message
+@dataclasses.dataclass
class FieldReference(Message):
"""IR for a "field" or "field.sub.subsub" reference in an expression.
@@ -608,16 +400,16 @@
# TODO(bolms): Make the above change before declaring the IR to be "stable".
- path = Repeated(Reference)
- source_location = Optional(Location)
+ path: list[Reference] = ir_data_fields.list_field(Reference)
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class OpaqueType(Message):
pass
-@message
+@dataclasses.dataclass
class IntegerType(Message):
"""Type of an integer expression."""
@@ -639,8 +431,21 @@
# the value from C's '%' operator when the dividend is negative: in C, -7 %
# 4 == -3, but the modular_value here would be 1. Python uses modulus: in
# Python, -7 % 4 == 1.
- modulus = Optional(_Text)
- modular_value = Optional(_Text)
+ modulus: Optional[str] = None
+ """The modulus portion of the modular congruence of an integer expression.
+
+ The modulus may be the special value "infinity" to indicate that the
+ expression's value is exactly modular_value; otherwise, it should be a
+ positive integer.
+
+ A modulus of 1 places no constraints on the value.
+ """
+ modular_value: Optional[str] = None
+ """ The modular_value portion of the modular congruence of an integer expression.
+
+ The modular_value should always be a nonnegative integer that is smaller
+ than the modulus.
+ """
# The minimum and maximum values of an integer are tracked and checked so
# that Emboss can implement reliable arithmetic with no operations
@@ -657,30 +462,30 @@
# Expression may only be evaluated during compilation; the back end should
# never need to compile such an expression into the target language (e.g.,
# C++).
- minimum_value = Optional(_Text)
- maximum_value = Optional(_Text)
+ minimum_value: Optional[str] = None
+ maximum_value: Optional[str] = None
-@message
+@dataclasses.dataclass
class BooleanType(Message):
- value = Optional(bool)
+ value: Optional[bool] = None
-@message
+@dataclasses.dataclass
class EnumType(Message):
- name = Optional(Reference)
- value = Optional(_Text)
+ name: Optional[Reference] = None
+ value: Optional[str] = None
-@message
+@dataclasses.dataclass
class ExpressionType(Message):
- opaque = Optional(OpaqueType, "type")
- integer = Optional(IntegerType, "type")
- boolean = Optional(BooleanType, "type")
- enumeration = Optional(EnumType, "type")
+ opaque: Optional[OpaqueType] = ir_data_fields.oneof_field("type")
+ integer: Optional[IntegerType] = ir_data_fields.oneof_field("type")
+ boolean: Optional[BooleanType] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[EnumType] = ir_data_fields.oneof_field("type")
-@message
+@dataclasses.dataclass
class Expression(Message):
"""IR for an expression.
@@ -689,68 +494,81 @@
other Expressions (function).
"""
- constant = Optional(NumericConstant, "expression")
- constant_reference = Optional(Reference, "expression")
- function = Optional(Function, "expression")
- field_reference = Optional(FieldReference, "expression")
- boolean_constant = Optional(BooleanConstant, "expression")
- builtin_reference = Optional(Reference, "expression")
+ constant: Optional[NumericConstant] = ir_data_fields.oneof_field("expression")
+ constant_reference: Optional[Reference] = ir_data_fields.oneof_field(
+ "expression"
+ )
+ function: Optional[Function] = ir_data_fields.oneof_field("expression")
+ field_reference: Optional[FieldReference] = ir_data_fields.oneof_field(
+ "expression"
+ )
+ boolean_constant: Optional[BooleanConstant] = ir_data_fields.oneof_field(
+ "expression"
+ )
+ builtin_reference: Optional[Reference] = ir_data_fields.oneof_field(
+ "expression"
+ )
- type = Optional(ExpressionType)
- source_location = Optional(Location)
+ type: Optional[ExpressionType] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class ArrayType(Message):
"""IR for an array type ("Int:8[12]" or "Message[2]" or "UInt[3][2]")."""
- base_type = Optional(lambda: Type)
- element_count = Optional(Expression, "size")
- automatic = Optional(Empty, "size")
+ base_type: Optional["Type"] = None
- source_location = Optional(Location)
+ element_count: Optional[Expression] = ir_data_fields.oneof_field("size")
+ automatic: Optional[Empty] = ir_data_fields.oneof_field("size")
+
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class AtomicType(Message):
"""IR for a non-array type ("UInt" or "Foo(Version.SIX)")."""
- reference = Optional(Reference)
- runtime_parameter = Repeated(Expression)
- source_location = Optional(Location)
+
+ reference: Optional[Reference] = None
+ runtime_parameter: list[Expression] = ir_data_fields.list_field(Expression)
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Type(Message):
"""IR for a type reference ("UInt", "Int:8[12]", etc.)."""
- atomic_type = Optional(AtomicType, "type")
- array_type = Optional(ArrayType, "type")
- size_in_bits = Optional(Expression)
- source_location = Optional(Location)
+ atomic_type: Optional[AtomicType] = ir_data_fields.oneof_field("type")
+ array_type: Optional[ArrayType] = ir_data_fields.oneof_field("type")
+
+ size_in_bits: Optional[Expression] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class AttributeValue(Message):
"""IR for a attribute value."""
+
# TODO(bolms): Make String a type of Expression, and replace
# AttributeValue with Expression.
- expression = Optional(Expression, "value")
- string_constant = Optional(String, "value")
+ expression: Optional[Expression] = ir_data_fields.oneof_field("value")
+ string_constant: Optional[String] = ir_data_fields.oneof_field("value")
- source_location = Optional(Location)
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Attribute(Message):
"""IR for a [name = value] attribute."""
- name = Optional(Word)
- value = Optional(AttributeValue)
- back_end = Optional(Word)
- is_default = Optional(bool)
- source_location = Optional(Location)
+
+ name: Optional[Word] = None
+ value: Optional[AttributeValue] = None
+ back_end: Optional[Word] = None
+ is_default: Optional[bool] = None
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class WriteTransform(Message):
"""IR which defines an expression-based virtual field write scheme.
@@ -764,43 +582,48 @@
`function_body` and `x` for `destination`.
"""
- function_body = Optional(Expression)
- destination = Optional(FieldReference)
+ function_body: Optional[Expression] = None
+ destination: Optional[FieldReference] = None
-@message
+@dataclasses.dataclass
class WriteMethod(Message):
"""IR which defines the method used for writing to a virtual field."""
- # A physical Field can be written directly.
- physical = Optional(bool, "method")
+ physical: Optional[bool] = ir_data_fields.oneof_field("method")
+ """A physical Field can be written directly."""
- # A read_only Field cannot be written.
- read_only = Optional(bool, "method")
+ read_only: Optional[bool] = ir_data_fields.oneof_field("method")
+ """A read_only Field cannot be written."""
- # An alias is a direct, untransformed forward of another field; it can be
- # implemented by directly returning a reference to the aliased field.
- #
- # Aliases are the only kind of virtual field that may have an opaque type.
- alias = Optional(FieldReference, "method")
+ alias: Optional[FieldReference] = ir_data_fields.oneof_field("method")
+ """An alias is a direct, untransformed forward of another field; it can be
+ implemented by directly returning a reference to the aliased field.
- # A transform is a way of turning a logical value into a value which should
- # be written to another field: A virtual field like `let y = x + 1` would
- # have a transform WriteMethod to subtract 1 from the new `y` value, and
- # write that to `x`.
- transform = Optional(WriteTransform, "method")
+ Aliases are the only kind of virtual field that may have an opaque type.
+ """
+
+ transform: Optional[WriteTransform] = ir_data_fields.oneof_field("method")
+ """A transform is a way of turning a logical value into a value which should
+ be written to another field.
+
+ A virtual field like `let y = x + 1` would
+ have a transform WriteMethod to subtract 1 from the new `y` value, and
+ write that to `x`.
+ """
-@message
+@dataclasses.dataclass
class FieldLocation(Message):
"""IR for a field location."""
- start = Optional(Expression)
- size = Optional(Expression)
- source_location = Optional(Location)
+
+ start: Optional[Expression] = None
+ size: Optional[Expression] = None
+ source_location: Optional[Location] = None
-@message
-class Field(Message):
+@dataclasses.dataclass
+class Field(Message): # pylint:disable=too-many-instance-attributes
"""IR for a field in a struct definition.
There are two kinds of Field: physical fields have location and (physical)
@@ -809,159 +632,189 @@
and they can be freely intermingled in the source file.
"""
- location = Optional(FieldLocation) # The physical location of the field.
- type = Optional(Type) # The physical type of the field.
+ location: Optional[FieldLocation] = None
+ """The physical location of the field."""
+ type: Optional[Type] = None
+ """The physical type of the field."""
- read_transform = Optional(Expression) # The value of a virtual field.
+ read_transform: Optional[Expression] = None
+ """The value of a virtual field."""
- # How this virtual field should be written.
- write_method = Optional(WriteMethod)
+ write_method: Optional[WriteMethod] = None
+ """How this virtual field should be written."""
- name = Optional(NameDefinition) # The name of the field.
- abbreviation = Optional(Word) # An optional short name for the field, only
- # visible inside the enclosing bits/struct.
- attribute = Repeated(Attribute) # Field-specific attributes.
- documentation = Repeated(Documentation) # Field-specific documentation.
+ name: Optional[NameDefinition] = None
+ """The name of the field."""
+ abbreviation: Optional[Word] = None
+ """An optional short name for the field, only visible inside the enclosing bits/struct."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """Field-specific attributes."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Field-specific documentation."""
- # The field only exists when existence_condition evaluates to true. For
- # example:
- #
- # struct Message:
- # 0 [+4] UInt length
- # 4 [+8] MessageType message_type
- # if message_type == MessageType.FOO:
- # 8 [+length] Foo foo
- # if message_type == MessageType.BAR:
- # 8 [+length] Bar bar
- # 8+length [+4] UInt crc
- #
- # For length, message_type, and crc, existence_condition will be
- # "boolean_constant { value: true }"
- #
- # For "foo", existence_condition will be:
- # function { function: EQUALITY
- # args: [reference to message_type]
- # args: { [reference to MessageType.FOO] } }
- #
- # The "bar" field will have a similar existence_condition to "foo":
- # function { function: EQUALITY
- # args: [reference to message_type]
- # args: { [reference to MessageType.BAR] } }
- #
- # When message_type is MessageType.BAR, the Message struct does not contain
- # field "foo", and vice versa for message_type == MessageType.FOO and field
- # "bar": those fields only conditionally exist in the structure.
- #
# TODO(bolms): Document conditional fields better, and replace some of this
# explanation with a reference to the documentation.
- existence_condition = Optional(Expression)
- source_location = Optional(Location)
+ existence_condition: Optional[Expression] = None
+ """The field only exists when existence_condition evaluates to true.
+
+ For example:
+ ```
+ struct Message:
+ 0 [+4] UInt length
+ 4 [+8] MessageType message_type
+ if message_type == MessageType.FOO:
+ 8 [+length] Foo foo
+ if message_type == MessageType.BAR:
+ 8 [+length] Bar bar
+ 8+length [+4] UInt crc
+ ```
+ For `length`, `message_type`, and `crc`, existence_condition will be
+ `boolean_constant { value: true }`
+
+ For `foo`, existence_condition will be:
+ ```
+ function { function: EQUALITY
+ args: [reference to message_type]
+ args: { [reference to MessageType.FOO] } }
+ ```
+
+ The `bar` field will have a similar existence_condition to `foo`:
+ ```
+ function { function: EQUALITY
+ args: [reference to message_type]
+ args: { [reference to MessageType.BAR] } }
+ ```
+
+ When `message_type` is `MessageType.BAR`, the `Message` struct does not contain
+ field `foo`, and vice versa for `message_type == MessageType.FOO` and field
+ `bar`: those fields only conditionally exist in the structure.
+ """
+
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Structure(Message):
"""IR for a bits or struct definition."""
- field = Repeated(Field)
- # The fields in `field` are listed in the order they appear in the original
- # .emb.
- #
- # For text format output, this can lead to poor results. Take the following
- # struct:
- #
- # struct Foo:
- # b [+4] UInt a
- # 0 [+4] UInt b
- #
- # Here, the location of `a` depends on the current value of `b`. Because of
- # this, if someone calls
- #
- # emboss::UpdateFromText(foo_view, "{ a: 10, b: 4 }");
- #
- # then foo_view will not be updated the way one would expect: if `b`'s value
- # was something other than 4 to start with, then `UpdateFromText` will write
- # the 10 to some other location, then update `b` to 4.
- #
- # To avoid surprises, `emboss::DumpAsText` should return `"{ b: 4, a: 10
- # }"`.
- #
- # The `fields_in_dependency_order` field provides a permutation of `field`
- # such that each field appears after all of its dependencies. For example,
- # `struct Foo`, above, would have `{ 1, 0 }` in
- # `fields_in_dependency_order`.
- #
- # The exact ordering of `fields_in_dependency_order` is not guaranteed, but
- # some effort is made to keep the order close to the order fields are listed
- # in the original `.emb` file. In particular, if the ordering 0, 1, 2, 3,
- # ... satisfies dependency ordering, then `fields_in_dependency_order` will
- # be `{ 0, 1, 2, 3, ... }`.
- fields_in_dependency_order = Repeated(int)
+ field: list[Field] = ir_data_fields.list_field(Field)
- source_location = Optional(Location)
+ fields_in_dependency_order: list[int] = ir_data_fields.list_field(int)
+ """The fields in `field` are listed in the order they appear in the original
+ .emb.
+
+ For text format output, this can lead to poor results. Take the following
+ struct:
+ ```
+ struct Foo:
+ b [+4] UInt a
+ 0 [+4] UInt b
+ ```
+ Here, the location of `a` depends on the current value of `b`. Because of
+ this, if someone calls
+ ```
+ emboss::UpdateFromText(foo_view, "{ a: 10, b: 4 }");
+ ```
+ then foo_view will not be updated the way one would expect: if `b`'s value
+ was something other than 4 to start with, then `UpdateFromText` will write
+ the 10 to some other location, then update `b` to 4.
+
+ To avoid surprises, `emboss::DumpAsText` should return `"{ b: 4, a: 10
+ }"`.
+
+ The `fields_in_dependency_order` field provides a permutation of `field`
+ such that each field appears after all of its dependencies. For example,
+ `struct Foo`, above, would have `{ 1, 0 }` in
+ `fields_in_dependency_order`.
+
+ The exact ordering of `fields_in_dependency_order` is not guaranteed, but
+ some effort is made to keep the order close to the order fields are listed
+ in the original `.emb` file. In particular, if the ordering 0, 1, 2, 3,
+ ... satisfies dependency ordering, then `fields_in_dependency_order` will
+ be `{ 0, 1, 2, 3, ... }`.
+ """
+
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class External(Message):
"""IR for an external type declaration."""
+
# Externals have no values other than name and attribute list, which are
# common to all type definitions.
- source_location = Optional(Location)
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class EnumValue(Message):
"""IR for a single value within an enumerated type."""
- name = Optional(NameDefinition) # The name of the enum value.
- value = Optional(Expression) # The value of the enum value.
- documentation = Repeated(Documentation) # Value-specific documentation.
- attribute = Repeated(Attribute) # Value-specific attributes.
- source_location = Optional(Location)
+ name: Optional[NameDefinition] = None
+ """The name of the enum value."""
+ value: Optional[Expression] = None
+ """The value of the enum value."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Value-specific documentation."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """Value-specific attributes."""
+
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Enum(Message):
"""IR for an enumerated type definition."""
- value = Repeated(EnumValue)
- source_location = Optional(Location)
+
+ value: list[EnumValue] = ir_data_fields.list_field(EnumValue)
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Import(Message):
"""IR for an import statement in a module."""
- file_name = Optional(String) # The file to import.
- local_name = Optional(Word) # The name to use within this module.
- source_location = Optional(Location)
+
+ file_name: Optional[String] = None
+ """The file to import."""
+ local_name: Optional[Word] = None
+ """The name to use within this module."""
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class RuntimeParameter(Message):
"""IR for a runtime parameter definition."""
- name = Optional(NameDefinition) # The name of the parameter.
- type = Optional(ExpressionType) # The type of the parameter.
- # For convenience and readability, physical types may be used in the .emb
- # source instead of a full expression type. That way, users can write
- # something like:
- #
- # struct Foo(version :: UInt:8):
- #
- # instead of:
- #
- # struct Foo(version :: {$int x |: 0 <= x <= 255}):
- #
- # In these cases, physical_type_alias holds the user-supplied type, and type
- # is filled in after initial parsing is finished.
- #
+ name: Optional[NameDefinition] = None
+ """The name of the parameter."""
+ type: Optional[ExpressionType] = None
+ """The type of the parameter."""
+
# TODO(bolms): Actually implement the set builder type notation.
- physical_type_alias = Optional(Type)
+ physical_type_alias: Optional[Type] = None
+ """For convenience and readability, physical types may be used in the .emb
+ source instead of a full expression type.
- source_location = Optional(Location)
+ That way, users can write
+ something like:
+ ```
+ struct Foo(version :: UInt:8):
+ ```
+ instead of:
+ ```
+ struct Foo(version :: {$int x |: 0 <= x <= 255}):
+ ```
+ In these cases, physical_type_alias holds the user-supplied type, and type
+ is filled in after initial parsing is finished.
+ """
+
+ source_location: Optional[Location] = None
class AddressableUnit(int, enum.Enum):
"""The "addressable unit" is the size of the smallest unit that can be read
+
from the backing store that this type expects. For `struct`s, this is
BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends
on the specific type
@@ -972,48 +825,70 @@
BYTE = 8
-@message
+@dataclasses.dataclass
class TypeDefinition(Message):
"""Container IR for a type definition (struct, union, etc.)"""
- external = Optional(External, "type")
- enumeration = Optional(Enum, "type")
- structure = Optional(Structure, "type")
+ external: Optional[External] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[Enum] = ir_data_fields.oneof_field("type")
+ structure: Optional[Structure] = ir_data_fields.oneof_field("type")
- name = Optional(NameDefinition) # The name of the type.
- attribute = Repeated(Attribute) # All attributes attached to the type.
- documentation = Repeated(Documentation) # Docs for the type.
+ name: Optional[NameDefinition] = None
+ """The name of the type."""
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """All attributes attached to the type."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Docs for the type."""
# pylint:disable=undefined-variable
- subtype = Repeated(lambda: TypeDefinition) # Subtypes of this type.
- addressable_unit = Optional(AddressableUnit)
+ subtype: list["TypeDefinition"] = ir_data_fields.list_field(
+ lambda: TypeDefinition
+ )
+ """Subtypes of this type."""
+ addressable_unit: Optional[AddressableUnit] = None
- # If the type requires parameters at runtime, these are its parameters.
- # These are currently only allowed on structures, but in the future they
- # should be allowed on externals.
- runtime_parameter = Repeated(RuntimeParameter)
+ runtime_parameter: list[RuntimeParameter] = ir_data_fields.list_field(
+ RuntimeParameter
+ )
+ """If the type requires parameters at runtime, these are its parameters.
- source_location = Optional(Location)
+ These are currently only allowed on structures, but in the future they
+ should be allowed on externals.
+ """
+ source_location: Optional[Location] = None
-@message
+@dataclasses.dataclass
class Module(Message):
"""The IR for an individual Emboss module (file)."""
- attribute = Repeated(Attribute) # Module-level attributes.
- type = Repeated(TypeDefinition) # Module-level type definitions.
- documentation = Repeated(Documentation) # Module-level docs.
- foreign_import = Repeated(Import) # Other modules imported.
- source_text = Optional(_Text) # The original source code.
- source_location = Optional(Location) # Source code covered by this IR.
- source_file_name = Optional(_Text) # Name of the source file.
+
+ attribute: list[Attribute] = ir_data_fields.list_field(Attribute)
+ """Module-level attributes."""
+ type: list[TypeDefinition] = ir_data_fields.list_field(TypeDefinition)
+ """Module-level type definitions."""
+ documentation: list[Documentation] = ir_data_fields.list_field(Documentation)
+ """Module-level docs."""
+ foreign_import: list[Import] = ir_data_fields.list_field(Import)
+ """Other modules imported."""
+ source_text: Optional[str] = None
+ """The original source code."""
+ source_location: Optional[Location] = None
+ """Source code covered by this IR."""
+ source_file_name: Optional[str] = None
+ """Name of the source file."""
-@message
+@dataclasses.dataclass
class EmbossIr(Message):
"""The top-level IR for an Emboss module and all of its dependencies."""
- # All modules. The first entry will be the main module; back ends should
- # generate code corresponding to that module. The second entry will be the
- # prelude module.
- module = Repeated(Module)
+
+ module: list[Module] = ir_data_fields.list_field(Module)
+ """All modules.
+
+ The first entry will be the main module; back ends should
+ generate code corresponding to that module. The second entry will be the
+ prelude module.
+ """
-_initialize_deferred_specs()
+# Post-process the dataclasses to add cached fields.
+ir_data_fields.cache_message_specs(sys.modules[Message.__module__], Message)
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
new file mode 100644
index 0000000..c74bd8f
--- /dev/null
+++ b/compiler/util/ir_data_fields.py
@@ -0,0 +1,449 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Provides helpers for defining and working with the fields of IR data classes.
+
+Field specs
+-----------
+Various utilities are provided for working with `dataclasses.Field`s:
+ - Classes:
+ - `FieldSpec` - used to track data about an IR data field
+ - `FieldContainer` - used to track the field container type
+ - Functions that work with `FieldSpec`s:
+ - `make_field_spec`, `build_default`
+ - Functions for retrieving a set of `FieldSpec`s for a given class:
+ - `field_specs`
+ - Functions for retrieving fields and their values:
+ - `fields_and_values`
+ - Functions for copying and updating IR data classes
+ - `copy`, `update`
+ - Functions to help defining IR data fields
+ - `oneof_field`, `list_field`, `str_field`
+"""
+
+import dataclasses
+import enum
+import sys
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ ForwardRef,
+ Iterable,
+ Mapping,
+ MutableMapping,
+ NamedTuple,
+ Optional,
+ Protocol,
+ SupportsIndex,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+ get_args,
+ get_origin,
+)
+
+
+class IrDataclassInstance(Protocol):
+ """Type bound for an IR dataclass instance."""
+
+ __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]]
+ IR_DATACLASS: ClassVar[object]
+ field_specs: ClassVar["FilteredIrFieldSpecs"]
+
+
+IrDataT = TypeVar("IrDataT", bound=IrDataclassInstance)
+
+CopyValuesListT = TypeVar("CopyValuesListT", bound=type)
+
+_IR_DATACLASSS_ATTR = "IR_DATACLASS"
+
+
+def _is_ir_dataclass(obj):
+ return hasattr(obj, _IR_DATACLASSS_ATTR)
+
+
+class CopyValuesList(list[CopyValuesListT]):
+ """A list that makes copies of any value that is inserted"""
+
+ def __init__(
+ self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None
+ ):
+ if iterable:
+ super().__init__(iterable)
+ else:
+ super().__init__()
+ self.value_type = value_type
+
+ def _copy(self, obj: Any):
+ if _is_ir_dataclass(obj):
+ return copy(obj)
+ return self.value_type(obj)
+
+ def extend(self, iterable: Iterable) -> None:
+ return super().extend([self._copy(i) for i in iterable])
+
+ def shallow_copy(self, iterable: Iterable) -> None:
+ """Explicitly performs a shallow copy of the provided list"""
+ return super().extend(iterable)
+
+ def append(self, obj: Any) -> None:
+ return super().append(self._copy(obj))
+
+ def insert(self, index: SupportsIndex, obj: Any) -> None:
+ return super().insert(index, self._copy(obj))
+
+
+class TemporaryCopyValuesList(NamedTuple):
+ """Class used to temporarily hold a CopyValuesList while copying and
+ constructing an IR dataclass.
+ """
+
+ temp_list: CopyValuesList
+
+
+class FieldContainer(enum.Enum):
+  """Indicates a field's container type"""
+
+ NONE = 0
+ OPTIONAL = 1
+ LIST = 2
+
+
+class FieldSpec(NamedTuple):
+ """Indicates the container and type of a field.
+
+ `FieldSpec` objects are accessed millions of times during runs so we cache as
+ many operations as possible.
+ - `is_dataclass`: `dataclasses.is_dataclass(data_type)`
+ - `is_sequence`: `container is FieldContainer.LIST`
+ - `is_enum`: `issubclass(data_type, enum.Enum)`
+ - `is_oneof`: `oneof is not None`
+
+ Use `make_field_spec` to automatically fill in the cached operations.
+ """
+
+ name: str
+ data_type: type
+ container: FieldContainer
+ oneof: Optional[str]
+ is_dataclass: bool
+ is_sequence: bool
+ is_enum: bool
+ is_oneof: bool
+
+
+def make_field_spec(
+ name: str, data_type: type, container: FieldContainer, oneof: Optional[str]
+):
+ """Builds a field spec with cached type queries."""
+ return FieldSpec(
+ name,
+ data_type,
+ container,
+ oneof,
+ is_dataclass=_is_ir_dataclass(data_type),
+ is_sequence=container is FieldContainer.LIST,
+ is_enum=issubclass(data_type, enum.Enum),
+ is_oneof=oneof is not None,
+ )
+
+
+def build_default(field_spec: FieldSpec):
+ """Builds a default instance of the given field"""
+ if field_spec.is_sequence:
+ return CopyValuesList(field_spec.data_type)
+ if field_spec.is_enum:
+ return field_spec.data_type(int())
+ return field_spec.data_type()
+
+
+class FilteredIrFieldSpecs:
+ """Provides cached views of an IR dataclass' fields."""
+
+ def __init__(self, specs: Mapping[str, FieldSpec]):
+ self.all_field_specs = specs
+ self.field_specs = tuple(specs.values())
+ self.dataclass_field_specs = {
+ k: v for k, v in specs.items() if v.is_dataclass
+ }
+ self.oneof_field_specs = {k: v for k, v in specs.items() if v.is_oneof}
+ self.sequence_field_specs = tuple(
+ v for v in specs.values() if v.is_sequence
+ )
+ self.oneof_mappings = tuple(
+ (k, v.oneof) for k, v in self.oneof_field_specs.items() if v.oneof
+ )
+
+
+def all_ir_classes(mod):
+ """Retrieves a list of all IR dataclass definitions in the given module."""
+ return (
+ v
+ for v in mod.__dict__.values()
+ if isinstance(type, v.__class__) and _is_ir_dataclass(v)
+ )
+
+
+class IrDataclassSpecs:
+ """Maintains a cache of all IR dataclass specs."""
+
+ spec_cache: MutableMapping[type, FilteredIrFieldSpecs] = {}
+
+ @classmethod
+ def get_mod_specs(cls, mod):
+ """Gets the IR dataclass specs for the given module."""
+ return {
+ ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
+ for ir_class in all_ir_classes(mod)
+ }
+
+ @classmethod
+ def get_specs(cls, data_class):
+ """Gets the field specs for the given class. The specs will be cached."""
+ if data_class not in cls.spec_cache:
+ mod = sys.modules[data_class.__module__]
+ cls.spec_cache.update(cls.get_mod_specs(mod))
+ return cls.spec_cache[data_class]
+
+
+def cache_message_specs(mod, cls):
+ """Adds a cached `field_specs` attribute to IR dataclasses in `mod`
+ excluding the given base `cls`.
+
+ This needs to be done after the dataclass decorators run and create the
+ wrapped classes.
+ """
+ for data_class in all_ir_classes(mod):
+ if data_class is not cls:
+ data_class.field_specs = IrDataclassSpecs.get_specs(data_class)
+
+
+def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]:
+ """Gets the IR data field names and types for the given IR data class"""
+ # Get the dataclass fields
+ class_fields = dataclasses.fields(cast(Any, cls))
+
+ # Pre-python 3.11 (maybe pre 3.10) `get_type_hints` will substitute
+ # `builtins.Expression` for 'Expression' rather than `ir_data.Expression`.
+  # Instead we manually substitute the type by extracting the list of classes
+ # from the class' module and manually substituting.
+ mod_ns = {
+ k: v
+ for k, v in sys.modules[cls.__module__].__dict__.items()
+ if isinstance(type, v.__class__)
+ }
+
+ # Now extract the concrete type out of optionals
+ result: MutableMapping[str, FieldSpec] = {}
+ for class_field in class_fields:
+ if class_field.name.startswith("_"):
+ continue
+ container_type = FieldContainer.NONE
+ type_hint = class_field.type
+ oneof = class_field.metadata.get("oneof")
+
+ # Check if this type is wrapped
+ origin = get_origin(type_hint)
+ # Get the wrapped types if there are any
+ args = get_args(type_hint)
+ if origin is not None:
+ # Extract the type.
+ type_hint = args[0]
+
+ # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we
+ # have to check if it's a `Union` instead of just using `Optional`.
+ if origin == Union:
+ # Make sure this is an `Optional` and not another `Union` type.
+ assert len(args) == 2 and args[1] == type(None)
+ container_type = FieldContainer.OPTIONAL
+ elif origin == list:
+ container_type = FieldContainer.LIST
+ else:
+ raise TypeError(f"Field has invalid container type: {origin}")
+
+ # Resolve any forward references.
+ if isinstance(type_hint, str):
+ type_hint = mod_ns[type_hint]
+ if isinstance(type_hint, ForwardRef):
+ type_hint = mod_ns[type_hint.__forward_arg__]
+
+ result[class_field.name] = make_field_spec(
+ class_field.name, type_hint, container_type, oneof
+ )
+
+ return result
+
+
+def field_specs(obj: IrDataT | type[IrDataT]) -> Mapping[str, FieldSpec]:
+  """Retrieves the field specs for the given data type.
+
+ The results of this method are cached to reduce lookup overhead.
+ """
+ cls = obj if isinstance(obj, type) else type(obj)
+ if cls is type(None):
+ raise TypeError("field_specs called with invalid type: NoneType")
+ return IrDataclassSpecs.get_specs(cls).all_field_specs
+
+
+def fields_and_values(
+ ir: IrDataT,
+ value_filt: Optional[Callable[[Any], bool]] = None,
+):
+ """Retrieves the fields and their values for a given IR data class.
+
+ Args:
+ ir: The IR data class or a read-only wrapper of an IR data class.
+ value_filt: Optional filter used to exclude values.
+ """
+ set_fields: list[Tuple[FieldSpec, Any]] = []
+ specs: FilteredIrFieldSpecs = ir.field_specs
+ for spec in specs.field_specs:
+ value = getattr(ir, spec.name)
+ if not value_filt or value_filt(value):
+ set_fields.append((spec, value))
+ return set_fields
+
+
+# `copy` is one of the hottest paths of embossc. We've taken steps to
+# optimize this path at the expense of code clarity and modularization.
+#
+# 1. `FilteredFieldSpecs` are cached on in the class definition for IR
+# dataclasses under the `ir_data.Message.field_specs` class attribute. We
+# just assume the passed in object has that attribute.
+# 2. We cache a `field_specs` entry that is just the `values()` of the
+# `all_field_specs` dict.
+# 3. Copied lists are wrapped in a `TemporaryCopyValuesList`. This is used to
+# signal to consumers that they can take ownership of the contained list
+#    rather than copying it again. See `ir_data.Message()` and `update()` for
+# where this is used.
+# 4. `FieldSpec` checks are cached including `is_dataclass` and `is_sequence`.
+# 5. None checks are only done in `copy()`, `_copy_set_fields` only
+# references `_copy()` to avoid this step.
+def _copy_set_fields(ir: IrDataT):
+ values: MutableMapping[str, Any] = {}
+
+ specs: FilteredIrFieldSpecs = ir.field_specs
+ for spec in specs.field_specs:
+ value = getattr(ir, spec.name)
+ if value is not None:
+ if spec.is_sequence:
+ if spec.is_dataclass:
+ copy_value = CopyValuesList(spec.data_type, (_copy(v) for v in value))
+ value = TemporaryCopyValuesList(copy_value)
+ else:
+ copy_value = CopyValuesList(spec.data_type, value)
+ value = TemporaryCopyValuesList(copy_value)
+ elif spec.is_dataclass:
+ value = _copy(value)
+ values[spec.name] = value
+ return values
+
+
+def _copy(ir: IrDataT) -> IrDataT:
+ return type(ir)(**_copy_set_fields(ir)) # type: ignore[misc]
+
+
+def copy(ir: IrDataT) -> IrDataT | None:
+ """Creates a copy of the given IR data class"""
+ if not ir:
+ return None
+ return _copy(ir)
+
+
+def update(ir: IrDataT, template: IrDataT):
+ """Updates `ir`s fields with all set fields in the template."""
+ for k, v in _copy_set_fields(template).items():
+ if isinstance(v, TemporaryCopyValuesList):
+ v = v.temp_list
+ setattr(ir, k, v)
+
+
+class OneOfField:
+ """Decorator for a "oneof" field.
+
+  Tracks when the field is set and will unset other fields in the associated
+ oneof group.
+
+ Note: Decorators only work if dataclass slots aren't used.
+ """
+
+ def __init__(self, oneof: str) -> None:
+ super().__init__()
+ self.oneof = oneof
+ self.owner_type = None
+ self.proxy_name: str = ""
+ self.name: str = ""
+
+ def __set_name__(self, owner, name):
+ self.name = name
+ self.proxy_name = f"_{name}"
+ self.owner_type = owner
+ # Add our empty proxy field to the class.
+ setattr(owner, self.proxy_name, None)
+
+ def __get__(self, obj, objtype=None):
+ return getattr(obj, self.proxy_name)
+
+ def __set__(self, obj, value):
+ if value is self:
+ # This will happen if the dataclass uses the default value, we just
+ # default to None.
+ value = None
+
+ if value is not None:
+ # Clear the others
+ for name, oneof in IrDataclassSpecs.get_specs(
+ self.owner_type
+ ).oneof_mappings:
+ if oneof == self.oneof and name != self.name:
+ setattr(obj, name, None)
+
+ setattr(obj, self.proxy_name, value)
+
+
+def oneof_field(name: str):
+  """Alternative for `dataclasses.field` that sets up a oneof variable"""
+ return dataclasses.field( # pylint:disable=invalid-field-call
+ default=OneOfField(name), metadata={"oneof": name}, init=True
+ )
+
+
+def str_field():
+ """Helper used to define a defaulted str field"""
+ return dataclasses.field(default_factory=str) # pylint:disable=invalid-field-call
+
+
+def list_field(cls_or_fn):
+ """Helper used to define a defaulted list field.
+
+ A lambda can be used to defer resolution of a field type that references its
+ container type, for example:
+ ```
+ class Foo:
+ subtypes: list['Foo'] = list_field(lambda: Foo)
+ names: list[str] = list_field(str)
+ ```
+
+ Args:
+ cls_or_fn: The class type or a function that resolves to the class type.
+ """
+
+ def list_factory(c):
+ return CopyValuesList(c if isinstance(c, type) else c())
+
+ return dataclasses.field( # pylint:disable=invalid-field-call
+ default_factory=lambda: list_factory(cls_or_fn)
+ )
diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py
new file mode 100644
index 0000000..fa99943
--- /dev/null
+++ b/compiler/util/ir_data_fields_test.py
@@ -0,0 +1,254 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for util.ir_data_fields."""
+
+import dataclasses
+import enum
+import sys
+from typing import Optional
+import unittest
+
+from compiler.util import ir_data
+from compiler.util import ir_data_fields
+
+
+class TestEnum(enum.Enum):
+ """Used to test python Enum handling."""
+
+ UNKNOWN = 0
+ VALUE_1 = 1
+ VALUE_2 = 2
+
+
+@dataclasses.dataclass
+class Opaque(ir_data.Message):
+ """Used for testing data field helpers"""
+
+
+@dataclasses.dataclass
+class ClassWithUnion(ir_data.Message):
+ """Used for testing data field helpers"""
+
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
+ integer: Optional[int] = ir_data_fields.oneof_field("type")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
+ non_union_field: int = 0
+
+
+@dataclasses.dataclass
+class ClassWithTwoUnions(ir_data.Message):
+ """Used for testing data field helpers"""
+
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
+ integer: Optional[int] = ir_data_fields.oneof_field("type_1")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
+ non_union_field: int = 0
+ seq_field: list[int] = ir_data_fields.list_field(int)
+
+
+@dataclasses.dataclass
+class NestedClass(ir_data.Message):
+ """Used for testing data field helpers"""
+
+ one_union_class: Optional[ClassWithUnion] = None
+ two_union_class: Optional[ClassWithTwoUnions] = None
+
+
+@dataclasses.dataclass
+class ListCopyTestClass(ir_data.Message):
+ """Used to test behavior of extending a sequence."""
+
+ non_union_field: int = 0
+ seq_field: list[int] = ir_data_fields.list_field(int)
+
+
+@dataclasses.dataclass
+class OneofFieldTest(ir_data.Message):
+ """Basic test class for oneof fields"""
+
+ int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
+ int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
+ normal_field: bool = True
+
+
+class OneOfTest(unittest.TestCase):
+ """Tests for the various oneof field helpers"""
+
+ def test_field_attribute(self):
+ """Test the `oneof_field` helper."""
+ test_field = ir_data_fields.oneof_field("type_1")
+ self.assertIsNotNone(test_field)
+ self.assertTrue(test_field.init)
+ self.assertIsInstance(test_field.default, ir_data_fields.OneOfField)
+ self.assertEqual(test_field.metadata.get("oneof"), "type_1")
+
+ def test_init_default(self):
+ """Test creating an instance with default fields"""
+ one_of_field_test = OneofFieldTest()
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertTrue(one_of_field_test.normal_field)
+
+ def test_init(self):
+ """Test creating an instance with non-default fields"""
+ one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertFalse(one_of_field_test.normal_field)
+
+ def test_set_oneof_field(self):
+ """Tests setting oneof fields causes others in the group to be unset"""
+ one_of_field_test = OneofFieldTest()
+ one_of_field_test.int_field_1 = 10
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertEqual(one_of_field_test.int_field_2, None)
+ one_of_field_test.int_field_2 = 20
+ self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertEqual(one_of_field_test.int_field_2, 20)
+
+ # Do it again
+ one_of_field_test.int_field_1 = 10
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertEqual(one_of_field_test.int_field_2, None)
+ one_of_field_test.int_field_2 = 20
+ self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertEqual(one_of_field_test.int_field_2, 20)
+
+ # Now create a new instance and make sure changes to it are not reflected
+ # on the original object.
+ one_of_field_test_2 = OneofFieldTest()
+ one_of_field_test_2.int_field_1 = 1000
+ self.assertEqual(one_of_field_test_2.int_field_1, 1000)
+ self.assertEqual(one_of_field_test_2.int_field_2, None)
+ self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertEqual(one_of_field_test.int_field_2, 20)
+
+ def test_set_to_none(self):
+ """Tests explicitly setting a oneof field to None"""
+ one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
+ self.assertEqual(one_of_field_test.int_field_1, 10)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertFalse(one_of_field_test.normal_field)
+
+ # Clear the set fields
+ one_of_field_test.int_field_1 = None
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertIsNone(one_of_field_test.int_field_2)
+ self.assertFalse(one_of_field_test.normal_field)
+
+ # Set another field
+ one_of_field_test.int_field_2 = 200
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertEqual(one_of_field_test.int_field_2, 200)
+ self.assertFalse(one_of_field_test.normal_field)
+
+ # Clear the already unset field
+ one_of_field_test.int_field_1 = None
+ self.assertIsNone(one_of_field_test.int_field_1)
+ self.assertEqual(one_of_field_test.int_field_2, 200)
+ self.assertFalse(one_of_field_test.normal_field)
+
+ def test_oneof_specs(self):
+ """Tests the `oneof_field_specs` filter"""
+ expected = {
+ "int_field_1": ir_data_fields.make_field_spec(
+ "int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
+ ),
+ "int_field_2": ir_data_fields.make_field_spec(
+ "int_field_2", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
+ ),
+ }
+ actual = ir_data_fields.IrDataclassSpecs.get_specs(
+ OneofFieldTest
+ ).oneof_field_specs
+ self.assertDictEqual(actual, expected)
+
+ def test_oneof_mappings(self):
+ """Tests the `oneof_mappings` function"""
+ expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
+ actual = ir_data_fields.IrDataclassSpecs.get_specs(
+ OneofFieldTest
+ ).oneof_mappings
+ self.assertTupleEqual(actual, expected)
+
+
+class IrDataFieldsTest(unittest.TestCase):
+ """Tests misc methods in ir_data_fields"""
+
+ def test_copy(self):
+ """Tests copying a data class works as expected"""
+ union = ClassWithTwoUnions(
+ opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
+ )
+ nested_class = NestedClass(two_union_class=union)
+ nested_class_copy = ir_data_fields.copy(nested_class)
+ self.assertIsNotNone(nested_class_copy)
+ self.assertIsNot(nested_class, nested_class_copy)
+ self.assertEqual(nested_class_copy, nested_class)
+
+ empty_copy = ir_data_fields.copy(None)
+ self.assertIsNone(empty_copy)
+
+ def test_copy_values_list(self):
+ """Tests that CopyValuesList copies values"""
+ data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
+ self.assertEqual(len(data_list), 0)
+
+ list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
+ list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
+ data_list.extend(list_tests)
+ self.assertEqual(len(data_list), 4)
+ for i in data_list:
+ self.assertEqual(i, list_test)
+
+ def test_list_param_is_copied(self):
+ """Test that lists passed to constructors are converted to CopyValuesList"""
+ seq_field = [5, 6, 7]
+ list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
+ self.assertEqual(len(list_test.seq_field), len(seq_field))
+ self.assertIsNot(list_test.seq_field, seq_field)
+ self.assertEqual(list_test.seq_field, seq_field)
+ self.assertTrue(
+ isinstance(list_test.seq_field, ir_data_fields.CopyValuesList)
+ )
+
+ def test_copy_oneof(self):
+ """Tests copying an IR data class that has oneof fields."""
+ oneof_test = OneofFieldTest()
+ oneof_test.int_field_1 = 10
+ oneof_test.normal_field = False
+ self.assertEqual(oneof_test.int_field_1, 10)
+ self.assertEqual(oneof_test.normal_field, False)
+
+ oneof_copy = ir_data_fields.copy(oneof_test)
+ self.assertIsNotNone(oneof_copy)
+ self.assertEqual(oneof_copy.int_field_1, 10)
+ self.assertIsNone(oneof_copy.int_field_2)
+ self.assertEqual(oneof_copy.normal_field, False)
+
+ oneof_copy.int_field_2 = 100
+ self.assertEqual(oneof_copy.int_field_2, 100)
+ self.assertIsNone(oneof_copy.int_field_1)
+ self.assertEqual(oneof_test.int_field_1, 10)
+ self.assertEqual(oneof_test.normal_field, False)
+
+
+ir_data_fields.cache_message_specs(
+ sys.modules[OneofFieldTest.__module__], ir_data.Message)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index ac02bb0..7de1979 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -12,51 +12,455 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+"""Provides helpers for working with IR data elements.
+
+Historical note: At one point protocol buffers were used for IR data. The
+codebase still expects the IR data classes to behave similarly, particularly
+with respect to "autovivification" where accessing an undefined field will
+create it temporarily and add it if assigned to. Though, perhaps not fully
+following the Pythonic ethos, we provide this behavior via the `builder` and
+`reader` helpers to remain compatible with the rest of the codebase.
+
+builder
+-------
+Instead of:
+```
+def set_function_name_end(function: Function):
+ if not function.function_name:
+ function.function_name = Word()
+ if not function.function_name.source_location:
+ function.function_name.source_location = Location()
+ word.source_location.end = Position(line=1,column=2)
+```
+
+We can do:
+```
+def set_function_name_end(function: Function):
+ builder(function).function_name.source_location.end = Position(line=1,
+ column=2)
+```
+
+reader
+------
+Instead of:
+```
+def is_leaf_synthetic(data):
+ if data:
+ if data.attribute:
+ if data.attribute.value:
+ if data.attribute.value.is_synthetic is not None:
+ return data.attribute.value.is_synthetic
+ return False
+```
+We can do:
+```
+def is_leaf_synthetic(data):
+ return reader(data).attribute.value.is_synthetic
+```
+
+IrDataSerializer
+----------------
+Provides methods for serializing and deserializing an IR data object.
+"""
+import enum
+import json
+from typing import (
+ Any,
+ Callable,
+ Generic,
+ MutableMapping,
+ MutableSequence,
+ Optional,
+ Tuple,
+ TypeVar,
+ cast,
+)
+
from compiler.util import ir_data
+from compiler.util import ir_data_fields
+
+
+MessageT = TypeVar("MessageT", bound=ir_data.Message)
+
+
+def field_specs(ir: MessageT | type[MessageT]):
+ """Retrieves the field specs for the IR data class"""
+ data_type = ir if isinstance(ir, type) else type(ir)
+ return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs
+
class IrDataSerializer:
"""Provides methods for serializing IR data objects"""
- def __init__(self, ir: ir_data.Message):
+ def __init__(self, ir: MessageT):
assert ir is not None
self.ir = ir
+ def _to_dict(
+ self,
+ ir: MessageT,
+ field_func: Callable[
+ [MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]
+ ],
+ ) -> MutableMapping[str, Any]:
+ assert ir is not None
+ values: MutableMapping[str, Any] = {}
+ for spec, value in field_func(ir):
+ if value is not None and spec.is_dataclass:
+ if spec.is_sequence:
+ value = [self._to_dict(v, field_func) for v in value]
+ else:
+ value = self._to_dict(value, field_func)
+ values[spec.name] = value
+ return values
+
+ def to_dict(self, exclude_none: bool = False):
+ """Converts the IR data class to a dictionary."""
+
+ def non_empty(ir):
+ return fields_and_values(
+ ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
+ )
+
+ def all_fields(ir):
+ return fields_and_values(ir)
+
+ # It's tempting to use `dataclasses.asdict` here, but that does a deep
+ # copy which is overkill for the current usage; mainly as an intermediary
+ # for `to_json` and `repr`.
+ return self._to_dict(self.ir, non_empty if exclude_none else all_fields)
+
def to_json(self, *args, **kwargs):
"""Converts the IR data class to a JSON string"""
- return self.ir.to_json(*args, **kwargs)
+ return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs)
@staticmethod
- def from_json(data_cls: type[ir_data.Message], data):
+ def from_json(data_cls, data):
"""Constructs an IR data class from the given JSON string"""
- return data_cls.from_json(data)
+ as_dict = json.loads(data)
+ return IrDataSerializer.from_dict(data_cls, as_dict)
+
+ def copy_from_dict(self, data):
+ """Deserializes the data and overwrites the IR data class with it"""
+ cls = type(self.ir)
+ data_copy = IrDataSerializer.from_dict(cls, data)
+ for k in field_specs(cls):
+ setattr(self.ir, k, getattr(data_copy, k))
+
+ @staticmethod
+ def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
+ if isinstance(val, str):
+ return getattr(enum_cls, val)
+ return enum_cls(val)
+
+ @staticmethod
+ def _enum_type_hook(enum_cls: type[enum.Enum]):
+ return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val)
+
+ @staticmethod
+ def _from_dict(data_cls: type[MessageT], data):
+ class_fields: MutableMapping[str, Any] = {}
+ for name, spec in ir_data_fields.field_specs(data_cls).items():
+ if (value := data.get(name)) is not None:
+ if spec.is_dataclass:
+ if spec.is_sequence:
+ class_fields[name] = [
+ IrDataSerializer._from_dict(spec.data_type, v) for v in value
+ ]
+ else:
+ class_fields[name] = IrDataSerializer._from_dict(
+ spec.data_type, value
+ )
+ else:
+ if spec.data_type in (
+ ir_data.FunctionMapping,
+ ir_data.AddressableUnit,
+ ):
+ class_fields[name] = IrDataSerializer._enum_type_converter(
+ spec.data_type, value
+ )
+ else:
+ if spec.is_sequence:
+ class_fields[name] = value
+ else:
+ class_fields[name] = spec.data_type(value)
+ return data_cls(**class_fields)
+
+ @staticmethod
+ def from_dict(data_cls: type[MessageT], data):
+ """Creates a new IR data instance from a serialized dict"""
+ return IrDataSerializer._from_dict(data_cls, data)
-def builder(target: ir_data.Message) -> ir_data.Message:
- """Provides a wrapper for building up IR data classes.
+class _IrDataSequenceBuilder(MutableSequence[MessageT]):
+ """Wrapper for a list of IR elements
- This is a no-op and just used for annotation for now.
+ Simply wraps the returned values during indexed access and iteration with
+ IrDataBuilders.
"""
- return target
+
+ def __init__(self, target: MutableSequence[MessageT]):
+ self._target = target
+
+ def __delitem__(self, key):
+ del self._target[key]
+
+ def __getitem__(self, key):
+ return _IrDataBuilder(self._target.__getitem__(key))
+
+ def __setitem__(self, key, value):
+ self._target[key] = value
+
+ def __iter__(self):
+ itr = iter(self._target)
+ for i in itr:
+ yield _IrDataBuilder(i)
+
+ def __repr__(self):
+ return repr(self._target)
+
+ def __len__(self):
+ return len(self._target)
+
+ def __eq__(self, other):
+ return self._target == other
+
+ def __ne__(self, other):
+ return self._target != other
+
+ def insert(self, index, value):
+ self._target.insert(index, value)
+
+ def extend(self, values):
+ self._target.extend(values)
-def reader(obj: ir_data.Message) -> ir_data.Message:
- """A read-only wrapper for querying chains of IR data fields.
+class _IrDataBuilder(Generic[MessageT]):
+ """Wrapper for an IR element"""
- This is a no-op and just used for annotation for now.
+ def __init__(self, ir: MessageT) -> None:
+ assert ir is not None
+ self.ir: MessageT = ir
+
+ def __setattr__(self, __name: str, __value: Any) -> None:
+ if __name == "ir":
+ # This is our proxy object
+ object.__setattr__(self, __name, __value)
+ else:
+ # Passthrough to the proxy object
+ ir: MessageT = object.__getattribute__(self, "ir")
+ setattr(ir, __name, __value)
+
+ def __getattribute__(self, name: str) -> Any:
+ """Hook for `getattr` that handles adding missing fields.
+
+ If the field is missing inserts it, and then returns either the raw value
+ for basic types
+ or a new IrBuilder wrapping the field to handle the next field access in a
+ longer chain.
+ """
+
+ # Check if getting one of the builder attributes
+ if name in ("CopyFrom", "ir"):
+ return object.__getattribute__(self, name)
+
+ # Get our target object by bypassing our getattr hook
+ ir: MessageT = object.__getattribute__(self, "ir")
+ if ir is None:
+ return object.__getattribute__(self, name)
+
+ if name in ("HasField", "WhichOneof"):
+ return getattr(ir, name)
+
+ field_spec = field_specs(ir).get(name)
+ if field_spec is None:
+ raise AttributeError(
+ f"No field {name} on {type(ir).__module__}.{type(ir).__name__}."
+ )
+
+ obj = getattr(ir, name, None)
+ if obj is None:
+ # Create a default and store it
+ obj = ir_data_fields.build_default(field_spec)
+ setattr(ir, name, obj)
+
+ if field_spec.is_dataclass:
+ obj = (
+ _IrDataSequenceBuilder(obj)
+ if field_spec.is_sequence
+ else _IrDataBuilder(obj)
+ )
+
+ return obj
+
+ def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name
+ """Updates the fields of this class with values set in the template"""
+ update(cast(type[MessageT], self), template)
+
+
+def builder(target: MessageT) -> MessageT:
+ """Create a wrapper around the target to help build an IR Data structure"""
+ # Check if the target is already a builder.
+ if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)):
+ return target
+
+ # Builders are only valid for IR data classes.
+ if not hasattr(type(target), "IR_DATACLASS"):
+ raise TypeError(f"Builder target {type(target)} is not an ir_data.message")
+
+ # Create a builder and cast it to the target type to expose type hinting for
+ # the wrapped type.
+ return cast(MessageT, _IrDataBuilder(target))
+
+
+def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
+ """Helper that builds a FieldChecker that pretends to be an IR class"""
+ if spec.is_sequence:
+ return []
+ if spec.is_dataclass:
+ return _ReadOnlyFieldChecker(spec)
+ return ir_data_fields.build_default(spec)
+
+
+def _field_type(ir_or_spec: MessageT | ir_data_fields.FieldSpec) -> type:
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ return ir_or_spec.data_type
+ return type(ir_or_spec)
+
+
+class _ReadOnlyFieldChecker:
+ """Class used to chain calls to fields that aren't set"""
+
+ def __init__(self, ir_or_spec: MessageT | ir_data_fields.FieldSpec) -> None:
+ self.ir_or_spec = ir_or_spec
+
+ def __setattr__(self, name: str, value: Any) -> None:
+ if name == "ir_or_spec":
+ return object.__setattr__(self, name, value)
+
+ raise AttributeError(f"Cannot set {name} on read-only wrapper")
+
+ def __getattribute__(self, name: str) -> Any: # pylint:disable=too-many-return-statements
+ ir_or_spec = object.__getattribute__(self, "ir_or_spec")
+ if name == "ir_or_spec":
+ return ir_or_spec
+
+ field_type = _field_type(ir_or_spec)
+ spec = field_specs(field_type).get(name)
+ if not spec:
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ if name == "HasField":
+ return lambda x: False
+ if name == "WhichOneof":
+ return lambda x: None
+ return object.__getattribute__(ir_or_spec, name)
+
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ # Just pretending
+ return _field_checker_from_spec(spec)
+
+ value = getattr(ir_or_spec, name)
+ if value is None:
+ return _field_checker_from_spec(spec)
+
+ if spec.is_dataclass:
+ if spec.is_sequence:
+ return [_ReadOnlyFieldChecker(i) for i in value]
+ return _ReadOnlyFieldChecker(value)
+
+ return value
+
+ def __eq__(self, other):
+ if isinstance(other, _ReadOnlyFieldChecker):
+ other = other.ir_or_spec
+ return self.ir_or_spec == other
+
+ def __ne__(self, other):
+ return not self == other
+
+
+def reader(obj: MessageT | _ReadOnlyFieldChecker) -> MessageT:
+ """Builds a read-only wrapper that can be used to check chains of possibly
+ unset fields.
+
+ This wrapper explicitly does not alter the wrapped object and is only
+ intended for reading contents.
+
+ For example, a `reader` lets you do:
+ ```
+ def get_function_name_end_column(function: ir_data.Function):
+ return reader(function).function_name.source_location.end.column
+ ```
+
+ Instead of:
+ ```
+ def get_function_name_end_column(function: ir_data.Function):
+ if function.function_name:
+ if function.function_name.source_location:
+ if function.function_name.source_location.end:
+ return function.function_name.source_location.end.column
+ return 0
+ ```
"""
- return obj
+ # Create a read-only wrapper if it's not already one.
+ if not isinstance(obj, _ReadOnlyFieldChecker):
+ obj = _ReadOnlyFieldChecker(obj)
+
+ # Cast it back to the original type.
+ return cast(MessageT, obj)
-def copy(ir: ir_data.Message | None) -> ir_data.Message | None:
+def _extract_ir(
+ ir_or_wrapper: MessageT | _ReadOnlyFieldChecker | _IrDataBuilder | None,
+) -> ir_data_fields.IrDataclassInstance | None:
+ if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
+ ir_or_spec = ir_or_wrapper.ir_or_spec
+ if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
+ # This is a placeholder entry, no fields are set.
+ return None
+ ir_or_wrapper = ir_or_spec
+ elif isinstance(ir_or_wrapper, _IrDataBuilder):
+ ir_or_wrapper = ir_or_wrapper.ir
+ return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
+
+
+def fields_and_values(
+ ir_wrapper: MessageT | _ReadOnlyFieldChecker,
+ value_filt: Optional[Callable[[Any], bool]] = None,
+) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
+ """Retrieves the fields and their values for a given IR data class.
+
+ Args:
+ ir_wrapper: The IR data class or a read-only wrapper of an IR data class.
+ value_filt: Optional filter used to exclude values.
+ """
+ if (ir := _extract_ir(ir_wrapper)) is None:
+ return []
+
+ return ir_data_fields.fields_and_values(ir, value_filt)
+
+
+def get_set_fields(ir: MessageT):
+ """Retrieves the field spec and value of fields that are set in the given IR data class.
+
+ A value is considered "set" if it is not None.
+ """
+ return fields_and_values(ir, lambda v: v is not None)
+
+
+def copy(ir_wrapper: MessageT | None) -> MessageT | None:
"""Creates a copy of the given IR data class"""
- if not ir:
+ if (ir := _extract_ir(ir_wrapper)) is None:
return None
- ir_class = type(ir)
- ir_copy = ir_class()
- update(ir_copy, ir)
- return ir_copy
+ ir_copy = ir_data_fields.copy(ir)
+ return cast(MessageT, ir_copy)
-def update(ir: ir_data.Message, template: ir_data.Message):
+def update(ir: MessageT, template: MessageT):
"""Updates `ir`s fields with all set fields in the template."""
- ir.CopyFrom(template)
+ if not (template_ir := _extract_ir(template)):
+ return
+
+ ir_data_fields.update(
+ cast(ir_data_fields.IrDataclassInstance, ir), template_ir
+ )
diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py
new file mode 100644
index 0000000..1c000ec
--- /dev/null
+++ b/compiler/util/ir_data_utils_test.py
@@ -0,0 +1,650 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Tests for util.ir_data_utils."""
+
+import dataclasses
+import enum
+import sys
+from typing import Optional
+import unittest
+from compiler.util import expression_parser
+from compiler.util import ir_data
+from compiler.util import ir_data_fields
+from compiler.util import ir_data_utils
+
+
+class TestEnum(enum.Enum):
+ """Used to test python Enum handling."""
+
+ UNKNOWN = 0
+ VALUE_1 = 1
+ VALUE_2 = 2
+
+
+@dataclasses.dataclass
+class Opaque(ir_data.Message):
+ """Used for testing data field helpers"""
+
+
+@dataclasses.dataclass
+class ClassWithUnion(ir_data.Message):
+ """Used for testing data field helpers"""
+
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
+ integer: Optional[int] = ir_data_fields.oneof_field("type")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type")
+ non_union_field: int = 0
+
+
+@dataclasses.dataclass
+class ClassWithTwoUnions(ir_data.Message):
+ """Used for testing data field helpers"""
+
+ opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
+ integer: Optional[int] = ir_data_fields.oneof_field("type_1")
+ boolean: Optional[bool] = ir_data_fields.oneof_field("type_2")
+ enumeration: Optional[TestEnum] = ir_data_fields.oneof_field("type_2")
+ non_union_field: int = 0
+ seq_field: list[int] = ir_data_fields.list_field(int)
+
+
+class IrDataUtilsTest(unittest.TestCase):
+ """Tests for the miscellaneous utility functions in ir_data_utils.py."""
+
+ def test_field_specs(self):
+ """Tests the `field_specs` method"""
+ fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
+ self.assertIsNotNone(fields)
+ expected_fields = (
+ "external",
+ "enumeration",
+ "structure",
+ "name",
+ "attribute",
+ "documentation",
+ "subtype",
+ "addressable_unit",
+ "runtime_parameter",
+ "source_location",
+ )
+ self.assertEqual(len(fields), len(expected_fields))
+ field_names = fields.keys()
+ for k in expected_fields:
+ self.assertIn(k, field_names)
+
+ # Try a sequence
+ expected_field = ir_data_fields.make_field_spec(
+ "attribute", ir_data.Attribute, ir_data_fields.FieldContainer.LIST, None
+ )
+ self.assertEqual(fields["attribute"], expected_field)
+
+ # Try a scalar
+ expected_field = ir_data_fields.make_field_spec(
+ "addressable_unit",
+ ir_data.AddressableUnit,
+ ir_data_fields.FieldContainer.OPTIONAL,
+ None,
+ )
+ self.assertEqual(fields["addressable_unit"], expected_field)
+
+ # Try a IR data class
+ expected_field = ir_data_fields.make_field_spec(
+ "source_location",
+ ir_data.Location,
+ ir_data_fields.FieldContainer.OPTIONAL,
+ None,
+ )
+ self.assertEqual(fields["source_location"], expected_field)
+
+ # Try an oneof field
+ expected_field = ir_data_fields.make_field_spec(
+ "external",
+ ir_data.External,
+ ir_data_fields.FieldContainer.OPTIONAL,
+ oneof="type",
+ )
+ self.assertEqual(fields["external"], expected_field)
+
+ # Try non-optional scalar
+ fields = ir_data_utils.field_specs(ir_data.Position)
+ expected_field = ir_data_fields.make_field_spec(
+ "line", int, ir_data_fields.FieldContainer.NONE, None
+ )
+ self.assertEqual(fields["line"], expected_field)
+
+ fields = ir_data_utils.field_specs(ir_data.ArrayType)
+ expected_field = ir_data_fields.make_field_spec(
+ "base_type", ir_data.Type, ir_data_fields.FieldContainer.OPTIONAL, None
+ )
+ self.assertEqual(fields["base_type"], expected_field)
+
+ def test_is_sequence(self):
+ """Tests for the `FieldSpec.is_sequence` helper"""
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ]
+ )
+ fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
+ # Test against a repeated field
+ self.assertTrue(fields["attribute"].is_sequence)
+ # Test against a nested IR data type
+ self.assertFalse(fields["name"].is_sequence)
+ # Test against a plain scalar type
+ fields = ir_data_utils.field_specs(type_def.attribute[0])
+ self.assertFalse(fields["is_default"].is_sequence)
+
+ def test_is_dataclass(self):
+ """Tests FieldSpec.is_dataclass against ir_data"""
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ]
+ )
+ fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
+ # Test against a repeated field that holds IR data structs
+ self.assertTrue(fields["attribute"].is_dataclass)
+ # Test against a nested IR data type
+ self.assertTrue(fields["name"].is_dataclass)
+ # Test against a plain scalar type
+ fields = ir_data_utils.field_specs(type_def.attribute[0])
+ self.assertFalse(fields["is_default"].is_dataclass)
+ # Test against a repeated field that holds scalars
+ fields = ir_data_utils.field_specs(ir_data.Structure)
+ self.assertFalse(fields["fields_in_dependency_order"].is_dataclass)
+
+ def test_get_set_fields(self):
+ """Tests that get set fields works"""
+ type_def = ir_data.TypeDefinition(
+ attribute=[
+ ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ ),
+ ]
+ )
+ set_fields = ir_data_utils.get_set_fields(type_def)
+ expected_fields = set(
+ ["attribute", "documentation", "subtype", "runtime_parameter"]
+ )
+ self.assertEqual(len(set_fields), len(expected_fields))
+ found_fields = set()
+ for k, v in set_fields:
+ self.assertIn(k.name, expected_fields)
+ found_fields.add(k.name)
+ self.assertEqual(v, getattr(type_def, k.name))
+
+ self.assertSetEqual(found_fields, expected_fields)
+
+ def test_copy(self):
+ """Tests the `copy` helper"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ attribute_copy = ir_data_utils.copy(attribute)
+
+ # Should be equivalent
+ self.assertEqual(attribute, attribute_copy)
+ # But not the same instance
+ self.assertIsNot(attribute, attribute_copy)
+
+ # Let's do a sequence
+ type_def = ir_data.TypeDefinition(attribute=[attribute])
+ type_def_copy = ir_data_utils.copy(type_def)
+
+ # Should be equivalent
+ self.assertEqual(type_def, type_def_copy)
+ # But not the same instance
+ self.assertIsNot(type_def, type_def_copy)
+ self.assertIsNot(type_def.attribute, type_def_copy.attribute)
+
+ def test_update(self):
+ """Tests the `update` helper"""
+ attribute_template = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ attribute = ir_data.Attribute(is_default=True)
+ ir_data_utils.update(attribute, attribute_template)
+ self.assertIsNotNone(attribute.value)
+ self.assertIsNot(attribute.value, attribute_template.value)
+ self.assertIsNotNone(attribute.name)
+ self.assertIsNot(attribute.name, attribute_template.name)
+
+ # Value not present in template should be untouched
+ self.assertTrue(attribute.is_default)
+
+
+class IrDataBuilderTest(unittest.TestCase):
+ """Tests for IrDataBuilder"""
+
+ def test_ir_data_builder(self):
+ """Tests that basic builder chains work"""
+ # We start with an empty type
+ type_def = ir_data.TypeDefinition()
+ self.assertFalse(type_def.HasField("name"))
+ self.assertIsNone(type_def.name)
+
+ # Now setup a builder
+ builder = ir_data_utils.builder(type_def)
+
+ # Assign to a sub-child
+ builder.name.name = ir_data.Word(text="phil")
+
+ # Verify the wrapped struct is updated
+ self.assertIsNotNone(type_def.name)
+ self.assertIsNotNone(type_def.name.name)
+ self.assertIsNotNone(type_def.name.name.text)
+ self.assertEqual(type_def.name.name.text, "phil")
+
+ def test_ir_data_builder_bad_field(self):
+ """Tests accessing an undefined field name fails"""
+ type_def = ir_data.TypeDefinition()
+ builder = ir_data_utils.builder(type_def)
+ self.assertRaises(AttributeError, lambda: builder.foo)
+ # Make sure it's not set on our IR data class either
+ self.assertRaises(AttributeError, getattr, type_def, "foo")
+
+ def test_ir_data_builder_sequence(self):
+ """Tests that sequences are properly wrapped"""
+ # We start with an empty type
+ type_def = ir_data.TypeDefinition()
+ self.assertTrue(type_def.HasField("attribute"))
+ self.assertEqual(len(type_def.attribute), 0)
+
+ # Now setup a builder
+ builder = ir_data_utils.builder(type_def)
+
+ # Assign to a sequence
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+
+ builder.attribute.append(attribute)
+ self.assertEqual(builder.attribute, [attribute])
+ self.assertTrue(type_def.HasField("attribute"))
+ self.assertEqual(len(type_def.attribute), 1)
+ self.assertEqual(type_def.attribute[0], attribute)
+
+    # Let's make it longer and then try iterating
+ builder.attribute.append(attribute)
+ self.assertEqual(len(type_def.attribute), 2)
+ for attr in builder.attribute:
+ # Modify the attributes
+ attr.name.text = "bob"
+
+ # Make sure we can build up auto-default entries from a sequence item
+ builder.attribute.append(ir_data.Attribute())
+ builder.attribute[-1].value.expression = ir_data.Expression()
+ builder.attribute[-1].name.text = "bob"
+
+ # Create an attribute to compare against
+ new_attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="bob"),
+ )
+
+ self.assertEqual(len(type_def.attribute), 3)
+ for attr in type_def.attribute:
+ self.assertEqual(attr, new_attribute)
+
+ # Make sure the list type is a CopyValuesList
+ self.assertIsInstance(
+ type_def.attribute,
+ ir_data_fields.CopyValuesList,
+ f"Instance is: {type(type_def.attribute)}",
+ )
+
+ def test_copy_from(self) -> None:
+ """Tests that `CopyFrom` works."""
+ location = ir_data.Location(
+ start=ir_data.Position(line=1, column=1),
+ end=ir_data.Position(line=1, column=2),
+ )
+ expression_ir = ir_data.Expression(source_location=location)
+ template: ir_data.Expression = expression_parser.parse("x + y")
+ expression = ir_data_utils.builder(expression_ir)
+ expression.CopyFrom(template)
+ self.assertIsNotNone(expression_ir.function)
+ self.assertIsInstance(expression.function, ir_data_utils._IrDataBuilder)
+ self.assertIsInstance(
+ expression.function.args, ir_data_utils._IrDataSequenceBuilder
+ )
+ self.assertTrue(expression_ir.function.args)
+
+ def test_copy_from_list(self):
+ specs = ir_data_utils.field_specs(ir_data.Function)
+ args_spec = specs["args"]
+ self.assertTrue(args_spec.is_dataclass)
+ template: ir_data.Expression = expression_parser.parse("x + y")
+ self.assertIsNotNone(template)
+ self.assertIsInstance(template, ir_data.Expression)
+ self.assertIsInstance(template.function, ir_data.Function)
+ self.assertIsInstance(template.function.args, ir_data_fields.CopyValuesList)
+
+ location = ir_data.Location(
+ start=ir_data.Position(line=1, column=1),
+ end=ir_data.Position(line=1, column=2),
+ )
+ expression_ir = ir_data.Expression(source_location=location)
+ self.assertIsInstance(expression_ir, ir_data.Expression)
+ self.assertIsNone(expression_ir.function)
+
+ expression_builder = ir_data_utils.builder(expression_ir)
+ self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder)
+ expression_builder.CopyFrom(template)
+ self.assertIsNotNone(expression_ir.function)
+ self.assertIsInstance(expression_ir.function, ir_data.Function)
+ self.assertIsNotNone(expression_ir.function.args)
+ self.assertIsInstance(
+ expression_ir.function.args, ir_data_fields.CopyValuesList
+ )
+
+ self.assertIsInstance(expression_builder, ir_data_utils._IrDataBuilder)
+ self.assertIsInstance(
+ expression_builder.function, ir_data_utils._IrDataBuilder
+ )
+ self.assertIsInstance(
+ expression_builder.function.args, ir_data_utils._IrDataSequenceBuilder
+ )
+
+ def test_ir_data_builder_sequence_scalar(self):
+ """Tests that sequences of scalars function properly"""
+ # We start with an empty type
+ structure = ir_data.Structure()
+
+ # Now setup a builder
+ builder = ir_data_utils.builder(structure)
+
+ # Assign to a scalar sequence
+ builder.fields_in_dependency_order.append(12)
+ builder.fields_in_dependency_order.append(11)
+
+ self.assertTrue(structure.HasField("fields_in_dependency_order"))
+ self.assertEqual(len(structure.fields_in_dependency_order), 2)
+ self.assertEqual(structure.fields_in_dependency_order[0], 12)
+ self.assertEqual(structure.fields_in_dependency_order[1], 11)
+ self.assertEqual(builder.fields_in_dependency_order, [12, 11])
+
+ new_structure = ir_data.Structure(fields_in_dependency_order=[12, 11])
+ self.assertEqual(structure, new_structure)
+
+ def test_ir_data_builder_oneof(self):
+ value = ir_data.AttributeValue(
+ expression=ir_data.Expression(
+ boolean_constant=ir_data.BooleanConstant()
+ )
+ )
+ builder = ir_data_utils.builder(value)
+ self.assertTrue(builder.HasField("expression"))
+ self.assertFalse(builder.expression.boolean_constant.value)
+ builder.expression.boolean_constant.value = True
+ self.assertTrue(builder.expression.boolean_constant.value)
+ self.assertTrue(value.expression.boolean_constant.value)
+
+ bool_constant = value.expression.boolean_constant
+ self.assertIsInstance(bool_constant, ir_data.BooleanConstant)
+
+
+class IrDataSerializerTest(unittest.TestCase):
+ """Tests for IrDataSerializer"""
+
+ def test_ir_data_serializer_to_dict(self):
+ """Tests serialization with `IrDataSerializer.to_dict` with default settings"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict()
+ expected = {
+ "name": {"text": "phil", "source_location": None},
+ "value": {
+ "expression": {
+ "constant": None,
+ "constant_reference": None,
+ "function": None,
+ "field_reference": None,
+ "boolean_constant": None,
+ "builtin_reference": None,
+ "type": None,
+ "source_location": None,
+ },
+ "string_constant": None,
+ "source_location": None,
+ },
+ "back_end": None,
+ "is_default": None,
+ "source_location": None,
+ }
+ self.assertDictEqual(raw_dict, expected)
+
+ def test_ir_data_serializer_to_dict_exclude_none(self):
+ """Tests serialization with `IrDataSerializer.to_dict` when excluding None values"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=True)
+ expected = {"name": {"text": "phil"}, "value": {"expression": {}}}
+ self.assertDictEqual(raw_dict, expected)
+
+ def test_ir_data_serializer_to_dict_enum(self):
+ """Tests that serialization of `enum.Enum` values works properly"""
+ type_def = ir_data.TypeDefinition(
+ addressable_unit=ir_data.AddressableUnit.BYTE
+ )
+ serializer = ir_data_utils.IrDataSerializer(type_def)
+ raw_dict = serializer.to_dict(exclude_none=True)
+ expected = {"addressable_unit": ir_data.AddressableUnit.BYTE}
+ self.assertDictEqual(raw_dict, expected)
+
+ def test_ir_data_serializer_from_dict(self):
+ """Tests deserializing IR data from a serialized dict"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=False)
+ new_attribute = serializer.from_dict(ir_data.Attribute, raw_dict)
+ self.assertEqual(attribute, new_attribute)
+
+ def test_ir_data_serializer_from_dict_enum(self):
+ """Tests that deserializing `enum.Enum` values works properly"""
+ type_def = ir_data.TypeDefinition(
+ addressable_unit=ir_data.AddressableUnit.BYTE
+ )
+
+ serializer = ir_data_utils.IrDataSerializer(type_def)
+ raw_dict = serializer.to_dict(exclude_none=False)
+ new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict)
+ self.assertEqual(type_def, new_type_def)
+
+ def test_ir_data_serializer_from_dict_enum_is_str(self):
+    """Tests that deserializing `enum.Enum` values works properly when a string constant is used"""
+ type_def = ir_data.TypeDefinition(
+ addressable_unit=ir_data.AddressableUnit.BYTE
+ )
+ raw_dict = {"addressable_unit": "BYTE"}
+ serializer = ir_data_utils.IrDataSerializer(type_def)
+ new_type_def = serializer.from_dict(ir_data.TypeDefinition, raw_dict)
+ self.assertEqual(type_def, new_type_def)
+
+ def test_ir_data_serializer_from_dict_exclude_none(self):
+ """Tests that deserializing from a dict that excluded None values works properly"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=True)
+ new_attribute = ir_data_utils.IrDataSerializer.from_dict(
+ ir_data.Attribute, raw_dict
+ )
+ self.assertEqual(attribute, new_attribute)
+
+ def test_from_dict_list(self):
+ function_args = [
+ {
+ "constant": {
+ "value": "0",
+ "source_location": {
+ "start": {"line": 421, "column": 3},
+ "end": {"line": 421, "column": 4},
+ "is_synthetic": False,
+ },
+ },
+ "type": {
+ "integer": {
+ "modulus": "infinity",
+ "modular_value": "0",
+ "minimum_value": "0",
+ "maximum_value": "0",
+ }
+ },
+ "source_location": {
+ "start": {"line": 421, "column": 3},
+ "end": {"line": 421, "column": 4},
+ "is_synthetic": False,
+ },
+ },
+ {
+ "constant": {
+ "value": "1",
+ "source_location": {
+ "start": {"line": 421, "column": 11},
+ "end": {"line": 421, "column": 12},
+ "is_synthetic": False,
+ },
+ },
+ "type": {
+ "integer": {
+ "modulus": "infinity",
+ "modular_value": "1",
+ "minimum_value": "1",
+ "maximum_value": "1",
+ }
+ },
+ "source_location": {
+ "start": {"line": 421, "column": 11},
+ "end": {"line": 421, "column": 12},
+ "is_synthetic": False,
+ },
+ },
+ ]
+ function_data = {"args": function_args}
+ func = ir_data_utils.IrDataSerializer.from_dict(
+ ir_data.Function, function_data
+ )
+ self.assertIsNotNone(func)
+
+ def test_ir_data_serializer_copy_from_dict(self):
+ """Tests that updating an IR data struct from a dict works properly"""
+ attribute = ir_data.Attribute(
+ value=ir_data.AttributeValue(expression=ir_data.Expression()),
+ name=ir_data.Word(text="phil"),
+ )
+ serializer = ir_data_utils.IrDataSerializer(attribute)
+ raw_dict = serializer.to_dict(exclude_none=False)
+
+ new_attribute = ir_data.Attribute()
+ new_serializer = ir_data_utils.IrDataSerializer(new_attribute)
+ new_serializer.copy_from_dict(raw_dict)
+ self.assertEqual(attribute, new_attribute)
+
+
+class ReadOnlyFieldCheckerTest(unittest.TestCase):
+ """Tests the ReadOnlyFieldChecker"""
+
+ def test_basic_wrapper(self):
+ """Tests basic field checker actions"""
+ union = ClassWithTwoUnions(
+ opaque=Opaque(), boolean=True, non_union_field=10
+ )
+ field_checker = ir_data_utils.reader(union)
+
+ # All accesses should return a wrapper object
+ self.assertIsNotNone(field_checker.opaque)
+ self.assertIsNotNone(field_checker.integer)
+ self.assertIsNotNone(field_checker.boolean)
+ self.assertIsNotNone(field_checker.enumeration)
+ self.assertIsNotNone(field_checker.non_union_field)
+ # Scalar field should pass through
+ self.assertEqual(field_checker.non_union_field, 10)
+
+ # Make sure HasField works
+ self.assertTrue(field_checker.HasField("opaque"))
+ self.assertFalse(field_checker.HasField("integer"))
+ self.assertTrue(field_checker.HasField("boolean"))
+ self.assertFalse(field_checker.HasField("enumeration"))
+ self.assertTrue(field_checker.HasField("non_union_field"))
+
+ def test_construct_from_field_checker(self):
+ """Tests that constructing from another field checker works"""
+ union = ClassWithTwoUnions(
+ opaque=Opaque(), boolean=True, non_union_field=10
+ )
+ field_checker_orig = ir_data_utils.reader(union)
+ field_checker = ir_data_utils.reader(field_checker_orig)
+ self.assertIsNotNone(field_checker)
+ self.assertEqual(field_checker.ir_or_spec, union)
+
+ # All accesses should return a wrapper object
+ self.assertIsNotNone(field_checker.opaque)
+ self.assertIsNotNone(field_checker.integer)
+ self.assertIsNotNone(field_checker.boolean)
+ self.assertIsNotNone(field_checker.enumeration)
+ self.assertIsNotNone(field_checker.non_union_field)
+ # Scalar field should pass through
+ self.assertEqual(field_checker.non_union_field, 10)
+
+ # Make sure HasField works
+ self.assertTrue(field_checker.HasField("opaque"))
+ self.assertFalse(field_checker.HasField("integer"))
+ self.assertTrue(field_checker.HasField("boolean"))
+ self.assertFalse(field_checker.HasField("enumeration"))
+ self.assertTrue(field_checker.HasField("non_union_field"))
+
+ def test_read_only(self) -> None:
+ """Tests that the read only wrapper really is read only"""
+ union = ClassWithTwoUnions(
+ opaque=Opaque(), boolean=True, non_union_field=10
+ )
+ field_checker = ir_data_utils.reader(union)
+
+ def set_field():
+ field_checker.opaque = None
+
+ self.assertRaises(AttributeError, set_field)
+
+
+ir_data_fields.cache_message_specs(
+ sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message)
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/compiler/util/ir_util.py b/compiler/util/ir_util.py
index 02dc647..603c0c0 100644
--- a/compiler/util/ir_util.py
+++ b/compiler/util/ir_util.py
@@ -25,6 +25,8 @@
def get_attribute(attribute_list, name):
"""Finds name in attribute_list and returns a AttributeValue or None."""
+ if not attribute_list:
+ return None
attribute_value = None
for attr in attribute_list:
if attr.name.text == name and not attr.is_default:
@@ -88,9 +90,11 @@
def constant_value(expression, bindings=None):
"""Evaluates expression with the given bindings."""
+ if expression is None:
+ return None
expression = ir_data_utils.reader(expression)
if expression.WhichOneof("expression") == "constant":
- return int(expression.constant.value)
+ return int(expression.constant.value or 0)
elif expression.WhichOneof("expression") == "constant_reference":
# We can't look up the constant reference without the IR, but by the time
# constant_value is called, the actual values should have been propagated to
@@ -252,7 +256,7 @@
def _find_path_in_parameters(path, type_definition):
- if len(path) > 1:
+ if len(path) > 1 or not type_definition.HasField("runtime_parameter"):
return None
for parameter in type_definition.runtime_parameter:
if ir_data_utils.reader(parameter).name.name.text == path[0]:
@@ -274,7 +278,7 @@
if obj:
return obj
else:
- return _find_path_in_type_list(path, type_definition.subtype)
+ return _find_path_in_type_list(path, type_definition.subtype or [])
def _find_path_in_type_list(path, type_list):
diff --git a/compiler/util/test_util.py b/compiler/util/test_util.py
index 0d33600..02ac9a3 100644
--- a/compiler/util/test_util.py
+++ b/compiler/util/test_util.py
@@ -14,7 +14,7 @@
"""Utilities for test code."""
-from compiler.util import ir_data
+from compiler.util import ir_data_utils
def proto_is_superset(proto, expected_values, path=""):
@@ -47,11 +47,12 @@
"""
if path:
path += "."
- for name, expected_value in expected_values.raw_fields.items():
+ for spec, expected_value in ir_data_utils.get_set_fields(expected_values):
+ name = spec.name
field_path = "{}{}".format(path, name)
value = getattr(proto, name)
- if issubclass(proto.field_specs[name].type, ir_data.Message):
- if isinstance(proto.field_specs[name], ir_data.Repeated):
+ if spec.is_dataclass:
+ if spec.is_sequence:
if len(expected_value) > len(value):
return False, "{}[{}] missing".format(field_path,
len(getattr(proto, name)))
@@ -71,9 +72,9 @@
# Zero-length repeated fields and not-there repeated fields are "the
# same."
if (expected_value != value and
- (isinstance(proto.field_specs[name], ir_data.Optional) or
+ (not spec.is_sequence or
len(expected_value))):
- if isinstance(proto.field_specs[name], ir_data.Repeated):
+ if spec.is_sequence:
return False, "{} differs: found {}, expected {}".format(
field_path, list(value), list(expected_value))
else:
diff --git a/compiler/util/traverse_ir.py b/compiler/util/traverse_ir.py
index 78efe46..93ffdc8 100644
--- a/compiler/util/traverse_ir.py
+++ b/compiler/util/traverse_ir.py
@@ -17,6 +17,8 @@
import inspect
from compiler.util import ir_data
+from compiler.util import ir_data_fields
+from compiler.util import ir_data_utils
from compiler.util import simple_memoizer
@@ -157,7 +159,7 @@
incidental_actions, new_pattern,
skip_descendants_of, action, parameters)
for member_name in repeated_fields:
- for array_element in getattr(proto, member_name):
+ for array_element in getattr(proto, member_name) or []:
_fast_traverse_proto_top_down(array_element, incidental_actions,
new_pattern, skip_descendants_of, action,
parameters)
@@ -192,12 +194,12 @@
if type_to_check in type_to_fields:
continue
fields = {}
- for field_name, field_type in type_to_check.field_specs.items():
- if issubclass(field_type.type, ir_data.Message):
- fields[field_name] = field_type.type
- types_to_check.append(field_type.type)
+ for field_name, field_type in ir_data_utils.field_specs(type_to_check).items():
+ if field_type.is_dataclass:
+ fields[field_name] = field_type.data_type
+ types_to_check.append(field_type.data_type)
type_fields_to_cardinality[type_to_check, field_name] = (
- field_type.__class__)
+ field_type.container)
type_to_fields[type_to_check] = fields
# type_to_descendant_types is a map of all types that can be reached from a
@@ -239,8 +241,8 @@
target_node_type in type_to_descendant_types[field_type]):
# Singular and repeated fields go to different lists, so that they can
# be handled separately.
- if (type_fields_to_cardinality[current_node_type, field_name] ==
- ir_data.Optional):
+ if (type_fields_to_cardinality[current_node_type, field_name] is not
+ ir_data_fields.FieldContainer.LIST):
singular_fields_to_scan.append(field_name)
else:
repeated_fields_to_scan.append(field_name)