Merge branch 'master' into re_fullmatch
diff --git a/compiler/back_end/cpp/header_generator.py b/compiler/back_end/cpp/header_generator.py
index e599585..6a66461 100644
--- a/compiler/back_end/cpp/header_generator.py
+++ b/compiler/back_end/cpp/header_generator.py
@@ -188,7 +188,7 @@
def _get_namespace_components(namespace):
- """Gets the components of a C++ namespace
+ """Gets the components of a C++ namespace.
Examples:
"::some::name::detail" -> ["some", "name", "detail"]
@@ -1339,10 +1339,12 @@
Arguments:
type_ir: The IR for the struct definition.
ir: The full IR; used for type lookups.
+ config: The code generation configuration to use.
Returns:
- A tuple of: (forward declaration for classes, class bodies, method bodies),
- suitable for insertion into the appropriate places in the generated header.
+ A tuple of: (forward declaration for classes, class bodies, method
+ bodies), suitable for insertion into the appropriate places in the
+ generated header.
"""
subtype_bodies, subtype_forward_declarations, subtype_method_definitions = (
_generate_subtype_definitions(type_ir, ir, config)
@@ -1530,10 +1532,17 @@
def _split_enum_case_values_into_spans(enum_case_value):
"""Yields spans containing each enum case in an enum_case attribute value.
- Each span is of the form (start, end), which is the start and end position
- relative to the beginning of the enum_case_value string. To keep the grammar
- of this attribute simple, this only splits on delimiters and trims whitespace
- for each case.
+ Arguments:
+ enum_case_value: the value of the `enum_case` attribute to be parsed.
+
+ Returns:
+ An iterator over spans, where each span covers one enum case name.
+ Each span is a half-open range of the form [start, end), which is the
+ start and end position relative to the beginning of the enum_case_value
+ string. The name can be retrieved with `enum_case_value[start:end]`.
+
+ To keep the grammar of this attribute simple, this only splits on
+ delimiters and trims whitespace for each case.
Example: 'SHOUTY_CASE, kCamelCase' -> [(0, 11), (13, 23)]"""
# Scan the string from left to right, finding commas and trimming whitespace.
@@ -1568,6 +1577,12 @@
def _split_enum_case_values(enum_case_value):
"""Returns all enum cases in an enum case value.
+ Arguments:
+ enum_case_value: the value of the enum case attribute to parse.
+
+ Returns:
+ All enum case names from `enum_case_value`.
+
Example: 'SHOUTY_CASE, kCamelCase' -> ['SHOUTY_CASE', 'kCamelCase']"""
return [
enum_case_value[start:end]
@@ -1576,7 +1591,7 @@
def _get_enum_value_names(enum_value):
- """Determines one or more enum names based on attributes"""
+ """Determines one or more enum names based on attributes."""
cases = ["SHOUTY_CASE"]
name = enum_value.name.name.text
if enum_case := ir_util.get_attribute(
@@ -1698,14 +1713,16 @@
Traverses the IR to propagate default values to target nodes.
Arguments:
- targets: A list of target IR types to add attributes to.
- ancestors: Ancestor types which may contain the default values.
- add_fn: Function to add the attribute. May use any parameter available in
- fast_traverse_ir_top_down actions as well as `defaults` containing the
- default attributes set by ancestors.
+ ir: The IR to process.
+ targets: A list of target IR types to add attributes to.
+ ancestors: Ancestor types which may contain the default values.
+        add_fn: Function to add the attribute. May use any parameter available
+            in fast_traverse_ir_top_down actions as well as `defaults`
+            containing the default attributes set by ancestors.
Returns:
- None
+ None
"""
traverse_ir.fast_traverse_ir_top_down(
ir,
@@ -1719,14 +1736,19 @@
def _offset_source_location_column(source_location, offset):
- """Adds offsets from the start column of the supplied source location
+ """Adds offsets from the start column of the supplied source location.
- Returns a new source location with all of the same properties as the provided
- source location, but with the columns modified by offsets from the original
- start column.
+ Arguments:
+        source_location: the initial source location.
+ offset: a tuple of (start, end), which are the offsets relative to
+ source_location.start.column to set the new start.column and
+ end.column.
- Offset should be a tuple of (start, end), which are the offsets relative to
- source_location.start.column to set the new start.column and end.column."""
+ Returns:
+ A new source location with all of the same properties as the provided
+ source location, but with the columns modified by offsets from the
+ original start column.
+ """
new_location = ir_data_utils.copy(source_location)
new_location.start.column = source_location.start.column + offset[0]
@@ -1863,8 +1885,12 @@
def _propagate_defaults_and_verify_attributes(ir):
"""Verify attributes and ensure defaults are set when not overridden.
- Returns a list of errors if there are errors present, or an empty list if
- verification completed successfully."""
+ Arguments:
+ ir: The IR to process.
+
+ Returns:
+ A list of errors if there are errors present, or an empty list if
+ verification completed successfully."""
if errors := attribute_util.check_attributes_in_ir(
ir,
back_end="cpp",
diff --git a/compiler/back_end/util/code_template_test.py b/compiler/back_end/util/code_template_test.py
index e4354fe..8149aaa 100644
--- a/compiler/back_end/util/code_template_test.py
+++ b/compiler/back_end/util/code_template_test.py
@@ -55,7 +55,7 @@
"""Tests for code_template.parse_templates."""
def assertTemplatesEqual(self, expected, actual): # pylint:disable=invalid-name
- """Compares the results of a parse_templates"""
+ """Compares the results of a parse_templates."""
# Extract the name and template from the result tuple
actual = {k: v.template for k, v in actual._asdict().items()}
self.assertEqual(expected, actual)
diff --git a/compiler/front_end/attribute_checker.py b/compiler/front_end/attribute_checker.py
index 78a2a1e..9c867d2 100644
--- a/compiler/front_end/attribute_checker.py
+++ b/compiler/front_end/attribute_checker.py
@@ -44,6 +44,7 @@
def _valid_back_ends(attr, module_source_file):
+ """Checks that `attr` holds a valid list of back end specifiers."""
if not re.fullmatch(
r"(?:\s*[a-z][a-z0-9_]*\s*(?:,\s*[a-z][a-z0-9_]*\s*)*,?)?\s*",
attr.value.string_constant.text,
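
Note on the span contract documented in `_split_enum_case_values_into_spans`
above: it can be sketched standalone. This is a simplified illustration of the
documented behavior (split on commas, trim whitespace, yield half-open
[start, end) spans), not the compiler's actual implementation:

```
def split_enum_case_spans(enum_case_value):
    """Illustrative: yields one (start, end) span per comma-separated case."""
    start = 0
    for chunk in enum_case_value.split(","):
        stripped = chunk.strip()
        begin = start + (chunk.index(stripped) if stripped else 0)
        yield (begin, begin + len(stripped))
        start += len(chunk) + 1  # skip past the comma

assert list(split_enum_case_spans("SHOUTY_CASE, kCamelCase")) == [(0, 11), (13, 23)]
assert "SHOUTY_CASE, kCamelCase"[13:23] == "kCamelCase"
```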
diff --git a/compiler/front_end/constraints.py b/compiler/front_end/constraints.py
index 3249e6d..852c9ad 100644
--- a/compiler/front_end/constraints.py
+++ b/compiler/front_end/constraints.py
@@ -436,6 +436,7 @@
def _check_allowed_in_bits(type_ir, type_definition, source_file_name, ir, errors):
+ """Verifies that atomic fields have types that are allowed in `bits`."""
if not type_ir.HasField("atomic_type"):
return
referenced_type_definition = ir_util.find_object(type_ir.atomic_type.reference, ir)
diff --git a/compiler/front_end/dependency_checker.py b/compiler/front_end/dependency_checker.py
index 8a9e903..fb622e1 100644
--- a/compiler/front_end/dependency_checker.py
+++ b/compiler/front_end/dependency_checker.py
@@ -23,14 +23,15 @@
def _add_reference_to_dependencies(
reference, dependencies, name, source_file_name, errors
):
+ """Adds the specified `reference` to the `dependencies` set."""
if reference.canonical_name.object_path[0] in {
"$is_statically_sized",
"$static_size_in_bits",
"$next",
}:
- # This error is a bit opaque, but given that the compiler used to crash on
- # this case -- for a couple of years -- and no one complained, it seems
- # safe to assume that this is a rare error.
+ # This error is a bit opaque, but given that the compiler used to crash
+ # on this case -- for a couple of years -- and no one complained, it
+ # seems safe to assume that this is a rare error.
errors.append(
[
error.error(
diff --git a/compiler/front_end/format_emb.py b/compiler/front_end/format_emb.py
index df9bcf3..fc7bf94 100644
--- a/compiler/front_end/format_emb.py
+++ b/compiler/front_end/format_emb.py
@@ -485,6 +485,7 @@
" type-definition* struct-field-block Dedent"
)
def _structure_body(indent, docs, attributes, type_definitions, fields, dedent, config):
+ """Formats a structure (`bits` or `struct`) body."""
del indent, dedent # Unused.
spacing = [_Row("field-separator")] if _should_add_blank_lines(fields) else []
columnized_fields = _columnize(fields, config.indent_width, indent_columns=2)
@@ -609,6 +610,7 @@
' field-location "bits" ":" Comment? eol anonymous-bits-body'
)
def _inline_bits(location, bits, colon, comment, eol, body):
+ """Formats an inline `bits` definition."""
# Even though an anonymous bits field technically defines a new, anonymous
# type, conceptually it's more like defining a bunch of fields on the
# surrounding type, so it is treated as an inline list of blocks, instead of
@@ -1017,6 +1019,7 @@
@_formats("or-expression-right -> or-operator comparison-expression")
@_formats("and-expression-right -> and-operator comparison-expression")
def _concatenate_with_prefix_spaces(*elements):
+ """Concatenates non-empty `elements` with leading spaces."""
return "".join(" " + element for element in elements if element)
@@ -1032,10 +1035,12 @@
)
@_formats('parameter-definition-list-tail -> "," parameter-definition')
def _concatenate_with_spaces(*elements):
+ """Concatenates non-empty `elements` with spaces between."""
return _concatenate_with(" ", *elements)
def _concatenate_with(joiner, *elements):
+ """Concatenates non-empty `elements` with `joiner` between."""
return joiner.join(element for element in elements if element)
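
The two concatenation helpers above drop empty elements before joining, which
avoids doubled or dangling separators; a minimal illustration (the helper body
is copied from the hunk, the asserts are added for demonstration):

```
def _concatenate_with(joiner, *elements):
    # Empty elements are skipped, so no doubled or dangling joiners appear.
    return joiner.join(element for element in elements if element)

assert _concatenate_with(" ", "a", "", "b") == "a b"
# The prefix-space variant behaves analogously:
assert "".join(" " + e for e in ("x", "", "y") if e) == " x y"
```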
diff --git a/compiler/front_end/lr1.py b/compiler/front_end/lr1.py
index be99f95..579d729 100644
--- a/compiler/front_end/lr1.py
+++ b/compiler/front_end/lr1.py
@@ -36,16 +36,17 @@
):
"""An Item is an LR(1) Item: a production, a cursor location, and a terminal.
- An Item represents a partially-parsed production, and a lookahead symbol. The
- position of the dot indicates what portion of the production has been parsed.
- Generally, Items are an internal implementation detail, but they can be useful
- elsewhere, particularly for debugging.
+ An Item represents a partially-parsed production, and a lookahead symbol.
+ The position of the dot indicates what portion of the production has been
+ parsed. Generally, Items are an internal implementation detail, but they
+ can be useful elsewhere, particularly for debugging.
Attributes:
- production: The Production this Item covers.
- dot: The index of the "dot" in production's rhs.
- terminal: The terminal lookahead symbol that follows the production in the
- input stream.
+ production: The Production this Item covers.
+ dot: The index of the "dot" in production's rhs.
+ terminal: The terminal lookahead symbol that follows the production in
+ the input stream.
+      next_symbol: The symbol immediately after the dot in the production's
+          rhs, or None if the dot is at the end of the production.
"""
def __str__(self):
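
For readers unfamiliar with LR(1) items, a toy rendering of the
production/dot/terminal triple described above; the names here are
illustrative stand-ins, not the real Production/Item classes:

```
from collections import namedtuple

Production = namedtuple("Production", ["lhs", "rhs"])  # illustrative stand-in

def item_str(production, dot, terminal):
    # The dot partitions rhs into the already-parsed and still-expected parts.
    parsed = " ".join(production.rhs[:dot])
    unparsed = " ".join(production.rhs[dot:])
    return f"{production.lhs} -> {parsed} . {unparsed}, {terminal}"

p = Production("expr", ("expr", "+", "term"))
assert item_str(p, 1, "$") == "expr -> expr . + term, $"
```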
diff --git a/compiler/front_end/module_ir.py b/compiler/front_end/module_ir.py
index bd27c8a..4a459c2 100644
--- a/compiler/front_end/module_ir.py
+++ b/compiler/front_end/module_ir.py
@@ -339,6 +339,7 @@
attribute_value,
close_bracket,
):
+ """Assembles an attribute IR node."""
del open_bracket, colon, close_bracket # Unused.
if context_specifier.list:
return ir_data.Attribute(
@@ -460,14 +461,15 @@
' ":" logical-expression'
)
def _choice_expression(condition, question, if_true, colon, if_false):
+ """Constructs an IR node for a choice operator (`?:`) expression."""
location = parser_types.make_location(
condition.source_location.start, if_false.source_location.end
)
operator_location = parser_types.make_location(
question.source_location.start, colon.source_location.end
)
- # The function_name is a bit weird, but should suffice for any error messages
- # that might need it.
+ # The function_name is a bit weird, but should suffice for any error
+ # messages that might need it.
return ir_data.Expression(
function=ir_data.Function(
function=ir_data.FunctionMapping.CHOICE,
@@ -1284,6 +1286,7 @@
def _enum_value(
name, equals, expression, attribute, documentation, comment, newline, body
):
+ """Constructs an IR node for an enum value statement (`NAME = value`)."""
del equals, comment, newline # Unused.
result = ir_data.EnumValue(
name=name,
diff --git a/compiler/front_end/synthetics.py b/compiler/front_end/synthetics.py
index 1b331a3..f55e32f 100644
--- a/compiler/front_end/synthetics.py
+++ b/compiler/front_end/synthetics.py
@@ -236,6 +236,7 @@
def _maybe_replace_next_keyword_in_expression(
expression_ir, last_location, source_file_name, errors
):
+ """Replaces the `$next` keyword in an expression."""
if not expression_ir.HasField("builtin_reference"):
return
if (
diff --git a/compiler/front_end/type_check.py b/compiler/front_end/type_check.py
index f562cc8..c3f7870 100644
--- a/compiler/front_end/type_check.py
+++ b/compiler/front_end/type_check.py
@@ -133,6 +133,7 @@
def _type_check_operation(expression, source_file_name, ir, errors):
+ """Type checks a function or operator expression."""
for arg in expression.function.args:
_type_check_expression(arg, source_file_name, ir, errors)
function = expression.function.function
diff --git a/compiler/util/attribute_util.py b/compiler/util/attribute_util.py
index a83cf51..2ebbffa 100644
--- a/compiler/util/attribute_util.py
+++ b/compiler/util/attribute_util.py
@@ -323,22 +323,24 @@
):
"""Performs basic checks on the given list of attributes.
- Checks the given attribute_list for duplicates, unknown attributes, attributes
- with incorrect type, and attributes whose values are not constant.
+ Checks the given attribute_list for duplicates, unknown attributes,
+ attributes with incorrect type, and attributes whose values are not
+ constant.
Arguments:
- attribute_list: An iterable of ir_data.Attribute.
- back_end: The qualifier for attributes to check, or None.
- attribute_specs: A dict of attribute names to _Attribute structures
- specifying the allowed attributes.
- context_name: A name for the context of these attributes, such as "struct
- 'Foo'" or "module 'm.emb'". Used in error messages.
- module_source_file: The value of module.source_file_name from the module
- containing 'attribute_list'. Used in error messages.
+ attribute_list: An iterable of ir_data.Attribute.
+ types: A map of attribute types to validators.
+ back_end: The qualifier for attributes to check, or None.
+ attribute_specs: A dict of attribute names to _Attribute structures
+ specifying the allowed attributes.
+ context_name: A name for the context of these attributes, such as
+ "struct 'Foo'" or "module 'm.emb'". Used in error messages.
+ module_source_file: The value of module.source_file_name from the module
+ containing 'attribute_list'. Used in error messages.
Returns:
- A list of lists of error.Errors. An empty list indicates no errors were
- found.
+ A list of lists of error.Errors. An empty list indicates no errors were
+ found.
"""
if attribute_specs is None:
attribute_specs = []
@@ -395,17 +397,17 @@
def gather_default_attributes(obj, defaults):
- """Gathers default attributes for an IR object
+ """Gathers default attributes for an IR object.
- This is designed to be able to be used as-is as an incidental action in an IR
- traversal to accumulate defaults for child nodes.
+ This is designed to be able to be used as-is as an incidental action in an
+ IR traversal to accumulate defaults for child nodes.
Arguments:
- defaults: A dict of `{ "defaults": { attr.name.text: attr } }`
+        obj: The IR object whose attributes should be gathered.
+        defaults: A dict of `{ "defaults": { attr.name.text: attr } }`
Returns:
- A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults
- provided by `obj` added/overridden.
+ A dict of `{ "defaults": { attr.name.text: attr } }` with any defaults
+ provided by `obj` added/overridden.
"""
defaults = defaults.copy()
for attr in obj.attribute:
diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py
index af8c2f7..fda124e 100644
--- a/compiler/util/ir_data.py
+++ b/compiler/util/ir_data.py
@@ -110,7 +110,10 @@
def WhichOneof(self, oneof_name): # pylint:disable=invalid-name
"""Indicates which field has been set for the oneof value.
- Returns None if no field has been set.
+ Args:
+        oneof_name: The name of the oneof construct to test.
+
+    Returns:
+        The field name, or None if no field has been set.
"""
for field_name, oneof in self.field_specs.oneof_mappings:
if oneof == oneof_name and self.HasField(field_name):
@@ -207,7 +210,7 @@
class FunctionMapping(int, enum.Enum):
- """Enum of supported function types"""
+ """Enum of supported function types."""
UNKNOWN = 0
ADDITION = 1
@@ -823,8 +826,9 @@
class AddressableUnit(int, enum.Enum):
- """The "addressable unit" is the size of the smallest unit that can be read
+ """The 'atom size' for a structure.
+ The "addressable unit" is the size of the smallest unit that can be read
from the backing store that this type expects. For `struct`s, this is
BYTE; for `enum`s and `bits`, this is BIT, and for `external`s it depends
-  on the specific type
+  on the specific type.
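
A sketch of the `WhichOneof` contract documented above, patterned on the
`OneofFieldTest` class defined later in this diff; it assumes, as those tests
do, that ad-hoc `ir_data.Message` subclasses get working field specs:

```
import dataclasses
from typing import Optional

from compiler.util import ir_data, ir_data_fields

@dataclasses.dataclass
class Sample(ir_data.Message):  # hypothetical, mirrors OneofFieldTest below
    int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
    int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")

msg = Sample(int_field_1=10)
assert msg.WhichOneof("type_1") == "int_field_1"
msg.int_field_2 = 20  # setting a sibling clears int_field_1
assert msg.WhichOneof("type_1") == "int_field_2"
assert Sample().WhichOneof("type_1") is None  # nothing set yet
```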
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index 002ea3b..248518b 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -76,10 +76,10 @@
class CopyValuesList(list[CopyValuesListT]):
- """A list that makes copies of any value that is inserted"""
+ """A list that makes copies of any value that is inserted."""
def __init__(
- self, value_type: CopyValuesListT, iterable: Optional[Iterable] = None
+ self, value_type: CopyValuesListT, iterable: Optional[Iterable[Any]] = None
):
if iterable:
super().__init__(iterable)
@@ -96,7 +96,7 @@
return super().extend([self._copy(i) for i in iterable])
def shallow_copy(self, iterable: Iterable) -> None:
- """Explicitly performs a shallow copy of the provided list"""
+ """Explicitly performs a shallow copy of the provided list."""
return super().extend(iterable)
def append(self, obj: Any) -> None:
@@ -107,15 +107,13 @@
class TemporaryCopyValuesList(NamedTuple):
- """Class used to temporarily hold a CopyValuesList while copying and
- constructing an IR dataclass.
- """
+ """Holder for a CopyValuesList while copying/constructing an IR dataclass."""
temp_list: CopyValuesList
class FieldContainer(enum.Enum):
- """Indicates a fields container type"""
+ """Indicates a fields container type."""
NONE = 0
OPTIONAL = 1
@@ -125,8 +123,8 @@
class FieldSpec(NamedTuple):
"""Indicates the container and type of a field.
- `FieldSpec` objects are accessed millions of times during runs so we cache as
- many operations as possible.
+ `FieldSpec` objects are accessed millions of times during runs so we cache
+ as many operations as possible.
- `is_dataclass`: `dataclasses.is_dataclass(data_type)`
- `is_sequence`: `container is FieldContainer.LIST`
- `is_enum`: `issubclass(data_type, enum.Enum)`
@@ -162,7 +160,7 @@
def build_default(field_spec: FieldSpec):
- """Builds a default instance of the given field"""
+ """Builds a default instance of the given field."""
if field_spec.is_sequence:
return CopyValuesList(field_spec.data_type)
if field_spec.is_enum:
@@ -216,11 +214,20 @@
def cache_message_specs(mod, cls):
- """Adds a cached `field_specs` attribute to IR dataclasses in `mod`
- excluding the given base `cls`.
+ """Adds `field_specs` to `mod`, excluding `cls`.
+
+ Adds a cached `field_specs` attribute to IR dataclasses in module `mod`
+ excluding the given base class `cls`.
This needs to be done after the dataclass decorators run and create the
wrapped classes.
+
+ Arguments:
+ mod: The module to process.
+ cls: The base class to exclude.
+
+ Returns:
+ None
"""
for data_class in all_ir_classes(mod):
if data_class is not cls:
@@ -228,7 +235,7 @@
def _field_specs(cls: type[IrDataT]) -> Mapping[str, FieldSpec]:
- """Gets the IR data field names and types for the given IR data class"""
+ """Gets the IR data field names and types for the given IR data class."""
# Get the dataclass fields
class_fields = dataclasses.fields(cast(Any, cls))
@@ -287,6 +294,12 @@
"""Retrieves the fields specs for the the give data type.
The results of this method are cached to reduce lookup overhead.
+
+ Arguments:
+ obj: Either an IR dataclass type, or an instance of such a type.
+
+ Returns:
+ The field specs for `obj`.
"""
cls = obj if isinstance(obj, type) else type(obj)
if cls is type(None):
@@ -301,8 +314,11 @@
"""Retrieves the fields and their values for a given IR data class.
Args:
- ir: The IR data class or a read-only wrapper of an IR data class.
- value_filt: Optional filter used to exclude values.
+ ir: The IR data class or a read-only wrapper of an IR data class.
+ value_filt: Optional filter used to exclude values.
+
+ Returns:
+        A list of `(FieldSpec, value)` tuples, optionally filtered by
+        `value_filt`.
"""
set_fields: list[Tuple[FieldSpec, Any]] = []
specs: FilteredIrFieldSpecs = ir.field_specs
@@ -329,6 +345,7 @@
# 5. None checks are only done in `copy()`, `_copy_set_fields` only
# references `_copy()` to avoid this step.
def _copy_set_fields(ir: IrDataT):
+ """Deep copies fields from IR node `ir`."""
values: MutableMapping[str, Any] = {}
specs: FilteredIrFieldSpecs = ir.field_specs
@@ -355,7 +372,7 @@
def copy(ir: IrDataT) -> Optional[IrDataT]:
- """Creates a copy of the given IR data class"""
+ """Creates a copy of the given IR data class."""
if not ir:
return None
return _copy(ir)
@@ -413,14 +430,14 @@
def oneof_field(name: str):
- """Alternative for `datclasses.field` that sets up a oneof variable"""
+ """Alternative for `datclasses.field` that sets up a oneof variable."""
return dataclasses.field( # pylint:disable=invalid-field-call
default=OneOfField(name), metadata={"oneof": name}, init=True
)
def str_field():
- """Helper used to define a defaulted str field"""
+ """Helper used to define a defaulted str field."""
return dataclasses.field(default_factory=str) # pylint:disable=invalid-field-call
@@ -436,7 +453,10 @@
```
Args:
- cls_or_fn: The class type or a function that resolves to the class type.
+ cls_or_fn: The class type or a function that resolves to the class type.
+
+ Returns:
+ A field with a `default_factory` that produces an appropriate list.
"""
def list_factory(c):
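
The copy-on-insert behavior of `CopyValuesList` versus `shallow_copy`, as
documented above, in a short sketch; `Node` is a hypothetical element type and
the pattern follows the tests in ir_data_fields_test.py:

```
import dataclasses

from compiler.util import ir_data, ir_data_fields

@dataclasses.dataclass
class Node(ir_data.Message):  # hypothetical element type
    value: int = 0

items = ir_data_fields.CopyValuesList(Node)
original = Node(value=1)
items.append(original)      # append() inserts a copy...
items[0].value = 99
assert original.value == 1  # ...so the original is untouched.

items.shallow_copy([original])  # shallow_copy() inserts the object itself,
items[1].value = 42             # so mutations are visible through both.
assert original.value == 42
```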
diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py
index 2853fde..103c1ad 100644
--- a/compiler/util/ir_data_fields_test.py
+++ b/compiler/util/ir_data_fields_test.py
@@ -34,12 +34,12 @@
@dataclasses.dataclass
class Opaque(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
integer: Optional[int] = ir_data_fields.oneof_field("type")
@@ -50,7 +50,7 @@
@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
integer: Optional[int] = ir_data_fields.oneof_field("type_1")
@@ -62,7 +62,7 @@
@dataclasses.dataclass
class NestedClass(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
one_union_class: Optional[ClassWithUnion] = None
two_union_class: Optional[ClassWithTwoUnions] = None
@@ -78,7 +78,7 @@
@dataclasses.dataclass
class OneofFieldTest(ir_data.Message):
- """Basic test class for oneof fields"""
+ """Basic test class for oneof fields."""
int_field_1: Optional[int] = ir_data_fields.oneof_field("type_1")
int_field_2: Optional[int] = ir_data_fields.oneof_field("type_1")
@@ -86,7 +86,7 @@
class OneOfTest(unittest.TestCase):
- """Tests for the the various oneof field helpers"""
+ """Tests for the the various oneof field helpers."""
def test_field_attribute(self):
"""Test the `oneof_field` helper."""
@@ -97,48 +97,48 @@
self.assertEqual(test_field.metadata.get("oneof"), "type_1")
def test_init_default(self):
- """Test creating an instance with default fields"""
+ """Test creating an instance with default fields."""
one_of_field_test = OneofFieldTest()
self.assertIsNone(one_of_field_test.int_field_1)
self.assertIsNone(one_of_field_test.int_field_2)
self.assertTrue(one_of_field_test.normal_field)
def test_init(self):
- """Test creating an instance with non-default fields"""
+ """Test creating an instance with non-default fields."""
one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
self.assertEqual(one_of_field_test.int_field_1, 10)
self.assertIsNone(one_of_field_test.int_field_2)
self.assertFalse(one_of_field_test.normal_field)
def test_set_oneof_field(self):
- """Tests setting oneof fields causes others in the group to be unset"""
+ """Tests setting oneof fields causes others in the group to be unset."""
one_of_field_test = OneofFieldTest()
one_of_field_test.int_field_1 = 10
self.assertEqual(one_of_field_test.int_field_1, 10)
- self.assertEqual(one_of_field_test.int_field_2, None)
+ self.assertIsNone(one_of_field_test.int_field_2)
one_of_field_test.int_field_2 = 20
- self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertIsNone(one_of_field_test.int_field_1)
self.assertEqual(one_of_field_test.int_field_2, 20)
# Do it again
one_of_field_test.int_field_1 = 10
self.assertEqual(one_of_field_test.int_field_1, 10)
- self.assertEqual(one_of_field_test.int_field_2, None)
+ self.assertIsNone(one_of_field_test.int_field_2)
one_of_field_test.int_field_2 = 20
- self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertIsNone(one_of_field_test.int_field_1)
self.assertEqual(one_of_field_test.int_field_2, 20)
- # Now create a new instance and make sure changes to it are not reflected
- # on the original object.
+ # Now create a new instance and make sure changes to it are not
+ # reflected on the original object.
one_of_field_test_2 = OneofFieldTest()
one_of_field_test_2.int_field_1 = 1000
self.assertEqual(one_of_field_test_2.int_field_1, 1000)
- self.assertEqual(one_of_field_test_2.int_field_2, None)
- self.assertEqual(one_of_field_test.int_field_1, None)
+ self.assertIsNone(one_of_field_test_2.int_field_2)
+ self.assertIsNone(one_of_field_test.int_field_1)
self.assertEqual(one_of_field_test.int_field_2, 20)
def test_set_to_none(self):
- """Tests explicitly setting a oneof field to None"""
+ """Tests explicitly setting a oneof field to None."""
one_of_field_test = OneofFieldTest(int_field_1=10, normal_field=False)
self.assertEqual(one_of_field_test.int_field_1, 10)
self.assertIsNone(one_of_field_test.int_field_2)
@@ -163,7 +163,7 @@
self.assertFalse(one_of_field_test.normal_field)
def test_oneof_specs(self):
- """Tests the `oneof_field_specs` filter"""
+ """Tests the `oneof_field_specs` filter."""
expected = {
"int_field_1": ir_data_fields.make_field_spec(
"int_field_1", int, ir_data_fields.FieldContainer.OPTIONAL, "type_1"
@@ -178,7 +178,7 @@
self.assertDictEqual(actual, expected)
def test_oneof_mappings(self):
- """Tests the `oneof_mappings` function"""
+ """Tests the `oneof_mappings` function."""
expected = (("int_field_1", "type_1"), ("int_field_2", "type_1"))
actual = ir_data_fields.IrDataclassSpecs.get_specs(
OneofFieldTest
@@ -187,10 +187,16 @@
class IrDataFieldsTest(unittest.TestCase):
- """Tests misc methods in ir_data_fields"""
+ """Tests misc methods in ir_data_fields."""
+
+ def assertEmpty(self, obj):
+ self.assertEqual(len(obj), 0, msg=f"{obj} is not empty.")
+
+    def assertLen(self, obj, length):
+        self.assertEqual(
+            len(obj), length, msg=f"{obj} has length {len(obj)}, not {length}."
+        )
def test_copy(self):
- """Tests copying a data class works as expected"""
+ """Tests copying a data class works as expected."""
union = ClassWithTwoUnions(
opaque=Opaque(), boolean=True, non_union_field=10, seq_field=[1, 2, 3]
)
@@ -204,25 +210,25 @@
self.assertIsNone(empty_copy)
def test_copy_values_list(self):
- """Tests that CopyValuesList copies values"""
+ """Tests that CopyValuesList copies values."""
data_list = ir_data_fields.CopyValuesList(ListCopyTestClass)
- self.assertEqual(len(data_list), 0)
+ self.assertEmpty(data_list)
list_test = ListCopyTestClass(non_union_field=2, seq_field=[5, 6, 7])
list_tests = [ir_data_fields.copy(list_test) for _ in range(4)]
data_list.extend(list_tests)
- self.assertEqual(len(data_list), 4)
+ self.assertLen(data_list, 4)
for i in data_list:
self.assertEqual(i, list_test)
def test_list_param_is_copied(self):
- """Test that lists passed to constructors are converted to CopyValuesList"""
+ """Test that lists passed to constructors are converted to CopyValuesList."""
seq_field = [5, 6, 7]
list_test = ListCopyTestClass(non_union_field=2, seq_field=seq_field)
- self.assertEqual(len(list_test.seq_field), len(seq_field))
+ self.assertLen(list_test.seq_field, len(seq_field))
self.assertIsNot(list_test.seq_field, seq_field)
self.assertEqual(list_test.seq_field, seq_field)
- self.assertTrue(isinstance(list_test.seq_field, ir_data_fields.CopyValuesList))
+ self.assertIsInstance(list_test.seq_field, ir_data_fields.CopyValuesList)
def test_copy_oneof(self):
"""Tests copying an IR data class that has oneof fields."""
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index 2d38eac..f880231 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -91,7 +91,7 @@
class IrDataSerializer:
- """Provides methods for serializing IR data objects"""
+ """Provides methods for serializing IR data objects."""
def __init__(self, ir: MessageT):
assert ir is not None
@@ -102,6 +102,7 @@
ir: MessageT,
field_func: Callable[[MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]],
) -> MutableMapping[str, Any]:
+ """Translates the IR to a standard Python `dict`."""
assert ir is not None
values: MutableMapping[str, Any] = {}
for spec, value in field_func(ir):
@@ -117,12 +118,12 @@
"""Converts the IR data class to a dictionary."""
def non_empty(ir):
- return fields_and_values(
+ return _fields_and_values(
ir, lambda v: v is not None and (not isinstance(v, list) or len(v))
)
def all_fields(ir):
- return fields_and_values(ir)
+ return _fields_and_values(ir)
# It's tempting to use `dataclasses.asdict` here, but that does a deep
# copy which is overkill for the current usage; mainly as an intermediary
@@ -130,17 +131,17 @@
return self._to_dict(self.ir, non_empty if exclude_none else all_fields)
def to_json(self, *args, **kwargs):
- """Converts the IR data class to a JSON string"""
+ """Converts the IR data class to a JSON string."""
return json.dumps(self.to_dict(exclude_none=True), *args, **kwargs)
@staticmethod
def from_json(data_cls, data):
- """Constructs an IR data class from the given JSON string"""
+ """Constructs an IR data class from the given JSON string."""
as_dict = json.loads(data)
return IrDataSerializer.from_dict(data_cls, as_dict)
def copy_from_dict(self, data):
- """Deserializes the data and overwrites the IR data class with it"""
+ """Deserializes the data and overwrites the IR data class with it."""
cls = type(self.ir)
data_copy = IrDataSerializer.from_dict(cls, data)
for k in field_specs(cls):
@@ -148,6 +149,7 @@
@staticmethod
def _enum_type_converter(enum_cls: type[enum.Enum], val: Any) -> enum.Enum:
+ """Converts `val` to an instance of `enum_cls`."""
if isinstance(val, str):
return getattr(enum_cls, val)
return enum_cls(val)
@@ -158,6 +160,7 @@
@staticmethod
def _from_dict(data_cls: type[MessageT], data):
+ """Translates the given `data` dict to an instance of `data_cls`."""
class_fields: MutableMapping[str, Any] = {}
for name, spec in ir_data_fields.field_specs(data_cls).items():
if (value := data.get(name)) is not None:
@@ -188,12 +191,12 @@
@staticmethod
def from_dict(data_cls: type[MessageT], data):
- """Creates a new IR data instance from a serialized dict"""
+ """Creates a new IR data instance from a serialized dict."""
return IrDataSerializer._from_dict(data_cls, data)
class _IrDataSequenceBuilder(MutableSequence[MessageT]):
- """Wrapper for a list of IR elements
+ """Wrapper for a list of IR elements.
Simply wraps the returned values during indexed access and iteration with
IrDataBuilders.
@@ -236,7 +239,7 @@
class _IrDataBuilder(Generic[MessageT]):
- """Wrapper for an IR element"""
+ """Wrapper for an IR element."""
def __init__(self, ir: MessageT) -> None:
assert ir is not None
@@ -254,10 +257,15 @@
def __getattribute__(self, name: str) -> Any:
"""Hook for `getattr` that handles adding missing fields.
- If the field is missing inserts it, and then returns either the raw value
- for basic types
- or a new IrBuilder wrapping the field to handle the next field access in a
- longer chain.
+        If the field is missing, inserts it, and then returns either the raw
+        value for basic types or a new _IrDataBuilder wrapping the field to
+        handle the next field access in a longer chain.
+
+        Arguments:
+            name: The name of the attribute to retrieve.
+
+ Returns:
+ The value of the attribute `name`.
"""
# Check if getting one of the builder attributes
@@ -294,12 +302,12 @@
return obj
def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name
- """Updates the fields of this class with values set in the template"""
+ """Updates the fields of this class with values set in the template."""
update(cast(type[MessageT], self), template)
def builder(target: MessageT) -> MessageT:
- """Create a wrapper around the target to help build an IR Data structure"""
+ """Create a wrapper around the target to help build an IR Data structure."""
# Check if the target is already a builder.
if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)):
return target
@@ -314,7 +322,7 @@
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
- """Helper that builds an FieldChecker that pretends to be an IR class"""
+ """Helper that builds an FieldChecker that pretends to be an IR class."""
if spec.is_sequence:
return []
if spec.is_dataclass:
@@ -322,14 +330,15 @@
return ir_data_fields.build_default(spec)
-def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type:
+def _field_type(ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> type[Any]:
+ """Returns the Python type of the given field."""
if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
return ir_or_spec.data_type
return type(ir_or_spec)
class _ReadOnlyFieldChecker:
- """Class used the chain calls to fields that aren't set"""
+ """Class used to chain calls to fields that aren't set."""
def __init__(self, ir_or_spec: Union[MessageT, ir_data_fields.FieldSpec]) -> None:
self.ir_or_spec = ir_or_spec
@@ -382,8 +391,7 @@
def reader(obj: Union[MessageT, _ReadOnlyFieldChecker]) -> MessageT:
- """Builds a read-only wrapper that can be used to check chains of possibly
- unset fields.
+ """Builds a wrapper that can be used to read chains of possibly unset fields.
This wrapper explicitly does not alter the wrapped object and is only
intended for reading contents.
@@ -391,18 +399,36 @@
For example, a `reader` lets you do:
```
def get_function_name_end_column(function: ir_data.Function):
- return reader(function).function_name.source_location.end.column
+ return reader(function).function_name.source_location.end.column
```
Instead of:
```
def get_function_name_end_column(function: ir_data.Function):
- if function.function_name:
- if function.function_name.source_location:
- if function.function_name.source_location.end:
- return function.function_name.source_location.end.column
- return 0
+ if function.function_name:
+ if function.function_name.source_location:
+ if function.function_name.source_location.end:
+ return function.function_name.source_location.end.column
+ return 0
```
+
+ Arguments:
+ obj: The IR node to wrap.
+
+ Returns:
+        An object whose attributes return one of:
+
+        - The value of `obj.attr`, if `attr` is of an atomic type and is set
+          on `obj`.
+        - A default value for `obj.attr`, if `obj.attr` is of an atomic type
+          but is not set.
+        - A read-only wrapper around `obj.attr`, if `obj.attr` is set and is
+          an IR node type.
+        - A read-only wrapper around an empty IR node object, if `obj.attr`
+          is of an IR node type but is not set.
"""
# Create a read-only wrapper if it's not already one.
if not isinstance(obj, _ReadOnlyFieldChecker):
@@ -426,15 +452,19 @@
return cast(ir_data_fields.IrDataclassInstance, ir_or_wrapper)
-def fields_and_values(
+def _fields_and_values(
ir_wrapper: Union[MessageT, _ReadOnlyFieldChecker],
value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
"""Retrieves the fields and their values for a given IR data class.
Args:
- ir: The IR data class or a read-only wrapper of an IR data class.
- value_filt: Optional filter used to exclude values.
+        ir_wrapper: The IR data class or a read-only wrapper of an IR data
+            class.
+ value_filt: Optional filter used to exclude values.
+
+ Returns:
+ Fields and their values for the IR held by `ir_wrapper`, optionally
+ filtered by `value_filt`.
"""
if (ir := _extract_ir(ir_wrapper)) is None:
return []
@@ -443,15 +473,22 @@
def get_set_fields(ir: MessageT):
- """Retrieves the field spec and value of fields that are set in the given IR data class.
+ """Retrieves the field specs and values of fields that are set in `ir`.
A value is considered "set" if it is not None.
+
+ Arguments:
+ ir: The IR node to operate on.
+
+ Returns:
+ The field specs and values of fields that are set in the given IR data
+ class.
"""
- return fields_and_values(ir, lambda v: v is not None)
+ return _fields_and_values(ir, lambda v: v is not None)
def copy(ir_wrapper: Optional[MessageT]) -> Optional[MessageT]:
- """Creates a copy of the given IR data class"""
+ """Creates a copy of the given IR data class."""
if (ir := _extract_ir(ir_wrapper)) is None:
return None
ir_copy = ir_data_fields.copy(ir)
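
The builder/reader pair described above in one short sketch, following the
patterns exercised by ir_data_utils_test.py below; the empty-string default
assumes str fields use `str_field`'s default factory:

```
from compiler.util import ir_data, ir_data_utils

type_def = ir_data.TypeDefinition()
assert not type_def.HasField("name")

# The builder inserts missing fields on access, so a long chain works even
# though `name` (and `name.name`) start out unset.
ir_data_utils.builder(type_def).name.name.text = "phil"
assert type_def.name.name.text == "phil"

# The reader never mutates: unset chains just yield defaults.
assert ir_data_utils.reader(ir_data.TypeDefinition()).name.name.text == ""
```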
diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py
index 82baac6..954fc33 100644
--- a/compiler/util/ir_data_utils_test.py
+++ b/compiler/util/ir_data_utils_test.py
@@ -35,12 +35,12 @@
@dataclasses.dataclass
class Opaque(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
@dataclasses.dataclass
class ClassWithUnion(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
opaque: Optional[Opaque] = ir_data_fields.oneof_field("type")
integer: Optional[int] = ir_data_fields.oneof_field("type")
@@ -51,7 +51,7 @@
@dataclasses.dataclass
class ClassWithTwoUnions(ir_data.Message):
- """Used for testing data field helpers"""
+ """Used for testing data field helpers."""
opaque: Optional[Opaque] = ir_data_fields.oneof_field("type_1")
integer: Optional[int] = ir_data_fields.oneof_field("type_1")
@@ -65,7 +65,7 @@
"""Tests for the miscellaneous utility functions in ir_data_utils.py."""
def test_field_specs(self):
- """Tests the `field_specs` method"""
+ """Tests the `field_specs` method."""
fields = ir_data_utils.field_specs(ir_data.TypeDefinition)
self.assertIsNotNone(fields)
expected_fields = (
@@ -132,7 +132,7 @@
self.assertEqual(fields["base_type"], expected_field)
def test_is_sequence(self):
- """Tests for the `FieldSpec.is_sequence` helper"""
+ """Tests for the `FieldSpec.is_sequence` helper."""
type_def = ir_data.TypeDefinition(
attribute=[
ir_data.Attribute(
@@ -151,7 +151,7 @@
self.assertFalse(fields["is_default"].is_sequence)
def test_is_dataclass(self):
- """Tests FieldSpec.is_dataclass against ir_data"""
+ """Tests FieldSpec.is_dataclass against ir_data."""
type_def = ir_data.TypeDefinition(
attribute=[
ir_data.Attribute(
@@ -173,7 +173,7 @@
self.assertFalse(fields["fields_in_dependency_order"].is_dataclass)
def test_get_set_fields(self):
- """Tests that get set fields works"""
+ """Tests that get set fields works."""
type_def = ir_data.TypeDefinition(
attribute=[
ir_data.Attribute(
@@ -196,7 +196,7 @@
self.assertSetEqual(found_fields, expected_fields)
def test_copy(self):
- """Tests the `copy` helper"""
+ """Tests the `copy` helper."""
attribute = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -219,7 +219,7 @@
self.assertIsNot(type_def.attribute, type_def_copy.attribute)
def test_update(self):
- """Tests the `update` helper"""
+ """Tests the `update` helper."""
attribute_template = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -236,10 +236,16 @@
class IrDataBuilderTest(unittest.TestCase):
- """Tests for IrDataBuilder"""
+ """Tests for IrDataBuilder."""
+
+ def assertEmpty(self, obj):
+ self.assertEqual(len(obj), 0, msg=f"{obj} is not empty.")
+
+    def assertLen(self, obj, length):
+        self.assertEqual(
+            len(obj), length, msg=f"{obj} has length {len(obj)}, not {length}."
+        )
def test_ir_data_builder(self):
- """Tests that basic builder chains work"""
+ """Tests that basic builder chains work."""
# We start with an empty type
type_def = ir_data.TypeDefinition()
self.assertFalse(type_def.HasField("name"))
@@ -258,7 +264,7 @@
self.assertEqual(type_def.name.name.text, "phil")
def test_ir_data_builder_bad_field(self):
- """Tests accessing an undefined field name fails"""
+ """Tests accessing an undefined field name fails."""
type_def = ir_data.TypeDefinition()
builder = ir_data_utils.builder(type_def)
self.assertRaises(AttributeError, lambda: builder.foo)
@@ -266,11 +272,11 @@
self.assertRaises(AttributeError, getattr, type_def, "foo")
def test_ir_data_builder_sequence(self):
- """Tests that sequences are properly wrapped"""
+ """Tests that sequences are properly wrapped."""
# We start with an empty type
type_def = ir_data.TypeDefinition()
self.assertTrue(type_def.HasField("attribute"))
- self.assertEqual(len(type_def.attribute), 0)
+ self.assertEmpty(type_def.attribute)
# Now setup a builder
builder = ir_data_utils.builder(type_def)
@@ -284,12 +290,12 @@
builder.attribute.append(attribute)
self.assertEqual(builder.attribute, [attribute])
self.assertTrue(type_def.HasField("attribute"))
- self.assertEqual(len(type_def.attribute), 1)
+ self.assertLen(type_def.attribute, 1)
self.assertEqual(type_def.attribute[0], attribute)
# Lets make it longer and then try iterating
builder.attribute.append(attribute)
- self.assertEqual(len(type_def.attribute), 2)
+ self.assertLen(type_def.attribute, 2)
for attr in builder.attribute:
# Modify the attributes
attr.name.text = "bob"
@@ -305,7 +311,7 @@
name=ir_data.Word(text="bob"),
)
- self.assertEqual(len(type_def.attribute), 3)
+ self.assertLen(type_def.attribute, 3)
for attr in type_def.attribute:
self.assertEqual(attr, new_attribute)
@@ -368,7 +374,7 @@
)
def test_ir_data_builder_sequence_scalar(self):
- """Tests that sequences of scalars function properly"""
+ """Tests that sequences of scalars function properly."""
# We start with an empty type
structure = ir_data.Structure()
@@ -380,7 +386,7 @@
builder.fields_in_dependency_order.append(11)
self.assertTrue(structure.HasField("fields_in_dependency_order"))
- self.assertEqual(len(structure.fields_in_dependency_order), 2)
+ self.assertLen(structure.fields_in_dependency_order, 2)
self.assertEqual(structure.fields_in_dependency_order[0], 12)
self.assertEqual(structure.fields_in_dependency_order[1], 11)
self.assertEqual(builder.fields_in_dependency_order, [12, 11])
@@ -404,10 +410,10 @@
class IrDataSerializerTest(unittest.TestCase):
- """Tests for IrDataSerializer"""
+ """Tests for IrDataSerializer."""
def test_ir_data_serializer_to_dict(self):
- """Tests serialization with `IrDataSerializer.to_dict` with default settings"""
+ """Tests serialization with `IrDataSerializer.to_dict` with default settings."""
attribute = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -438,7 +444,7 @@
self.assertDictEqual(raw_dict, expected)
def test_ir_data_serializer_to_dict_exclude_none(self):
- """Tests serialization with `IrDataSerializer.to_dict` when excluding None values"""
+ """.Tests serialization with `IrDataSerializer.to_dict` when excluding None values"""
attribute = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -449,7 +455,7 @@
self.assertDictEqual(raw_dict, expected)
def test_ir_data_serializer_to_dict_enum(self):
- """Tests that serialization of `enum.Enum` values works properly"""
+ """Tests that serialization of `enum.Enum` values works properly."""
type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE)
serializer = ir_data_utils.IrDataSerializer(type_def)
raw_dict = serializer.to_dict(exclude_none=True)
@@ -457,7 +463,7 @@
self.assertDictEqual(raw_dict, expected)
def test_ir_data_serializer_from_dict(self):
- """Tests deserializing IR data from a serialized dict"""
+ """Tests deserializing IR data from a serialized dict."""
attribute = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -468,7 +474,7 @@
self.assertEqual(attribute, new_attribute)
def test_ir_data_serializer_from_dict_enum(self):
- """Tests that deserializing `enum.Enum` values works properly"""
+ """Tests that deserializing `enum.Enum` values works properly."""
type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE)
serializer = ir_data_utils.IrDataSerializer(type_def)
@@ -477,7 +483,7 @@
self.assertEqual(type_def, new_type_def)
def test_ir_data_serializer_from_dict_enum_is_str(self):
- """Tests that deserializing `enum.Enum` values works properly when string constant is used"""
+ """Tests that deserializing `enum.Enum` values works properly when string constant is used."""
type_def = ir_data.TypeDefinition(addressable_unit=ir_data.AddressableUnit.BYTE)
raw_dict = {"addressable_unit": "BYTE"}
serializer = ir_data_utils.IrDataSerializer(type_def)
@@ -485,7 +491,7 @@
self.assertEqual(type_def, new_type_def)
def test_ir_data_serializer_from_dict_exclude_none(self):
- """Tests that deserializing from a dict that excluded None values works properly"""
+ """Tests that deserializing from a dict that excluded None values works properly."""
attribute = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -552,7 +558,7 @@
self.assertIsNotNone(func)
def test_ir_data_serializer_copy_from_dict(self):
- """Tests that updating an IR data struct from a dict works properly"""
+ """Tests that updating an IR data struct from a dict works properly."""
attribute = ir_data.Attribute(
value=ir_data.AttributeValue(expression=ir_data.Expression()),
name=ir_data.Word(text="phil"),
@@ -567,10 +573,10 @@
class ReadOnlyFieldCheckerTest(unittest.TestCase):
- """Tests the ReadOnlyFieldChecker"""
+ """Tests the ReadOnlyFieldChecker."""
def test_basic_wrapper(self):
- """Tests basic field checker actions"""
+ """Tests basic field checker actions."""
union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10)
field_checker = ir_data_utils.reader(union)
@@ -591,7 +597,7 @@
self.assertTrue(field_checker.HasField("non_union_field"))
def test_construct_from_field_checker(self):
- """Tests that constructing from another field checker works"""
+ """Tests that constructing from another field checker works."""
union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10)
field_checker_orig = ir_data_utils.reader(union)
field_checker = ir_data_utils.reader(field_checker_orig)
@@ -615,7 +621,7 @@
self.assertTrue(field_checker.HasField("non_union_field"))
def test_read_only(self) -> None:
- """Tests that the read only wrapper really is read only"""
+ """Tests that the read only wrapper really is read only."""
union = ClassWithTwoUnions(opaque=Opaque(), boolean=True, non_union_field=10)
field_checker = ir_data_utils.reader(union)
diff --git a/compiler/util/name_conversion.py b/compiler/util/name_conversion.py
index b13d667..b4303fb 100644
--- a/compiler/util/name_conversion.py
+++ b/compiler/util/name_conversion.py
@@ -63,10 +63,19 @@
def convert_case(case_from, case_to, value):
"""Converts cases based on runtime case values.
- Note: Cases can be strings or enum values."""
+ Note: Cases can be strings or enum values.
+
+ Arguments:
+      case_from: The name of the original case.
+      case_to: The name of the desired case.
+      value: The value to convert.
+
+ Returns:
+ `value` converted from `case_from` to `case_to`.
+ """
return _case_conversions[case_from, case_to](value)
def is_case_conversion_supported(case_from, case_to):
- """Determines if a case conversion would be supported"""
+ """Determines if a case conversion would be supported."""
return (case_from, case_to) in _case_conversions
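
A guarded usage sketch for the two helpers above; the case identifiers are
assumed to match the ones used by the `enum_case` attribute elsewhere in this
change ("SHOUTY_CASE", "kCamelCase") and may differ from the actual registered
names:

```
from compiler.util import name_conversion

if name_conversion.is_case_conversion_supported("SHOUTY_CASE", "kCamelCase"):
    name = name_conversion.convert_case("SHOUTY_CASE", "kCamelCase", "FOO_BAR_BAZ")
    # Expected to be the k-prefixed camel form, e.g. "kFooBarBaz".
```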
diff --git a/runtime/cpp/emboss_prelude.h b/runtime/cpp/emboss_prelude.h
index 8b5fa28..ad302cd 100644
--- a/runtime/cpp/emboss_prelude.h
+++ b/runtime/cpp/emboss_prelude.h
@@ -267,7 +267,7 @@
static_cast</**/ ::std::uint64_t>(value) <=
((static_cast<ValueType>(1) << (Parameters::kBits - 1)) << 1) -
1 &&
- Parameters::ValueIsOk(value);
+ Parameters::ValueIsOk(static_cast<ValueType>(value));
}
void UncheckedWrite(ValueType value) const {
buffer_.UncheckedWriteUInt(value);
@@ -429,7 +429,7 @@
: ((static_cast<ValueType>(1) << (Parameters::kBits - 2)) -
1) * 2 +
1) &&
- Parameters::ValueIsOk(value);
+ Parameters::ValueIsOk(static_cast<ValueType>(value));
}
void UncheckedWrite(ValueType value) const {