Use MessageT in ir_data_utils.py
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
index d2f0536..fb8c01c 100644
--- a/compiler/util/ir_data_utils.py
+++ b/compiler/util/ir_data_utils.py
@@ -80,10 +80,10 @@
from compiler.util import ir_data_fields
-T = TypeVar("T", bound=ir_data.Message)
+MessageT = TypeVar("MessaageT", bound=ir_data.Message)
-def field_specs(ir: T | type[T]):
+def field_specs(ir: MessageT | type[MessageT]):
"""Retrieves the field specs for the IR data class"""
data_type = ir if isinstance(ir, type) else type(ir)
return ir_data_fields.IrDataclassSpecs.get_specs(data_type).all_field_specs
@@ -92,14 +92,14 @@
class IrDataSerializer:
"""Provides methods for serializing IR data objects"""
- def __init__(self, ir: T):
+ def __init__(self, ir: MessageT):
assert ir is not None
self.ir = ir
def _to_dict(
self,
- ir: T,
- field_func: Callable[[T], list[Tuple[ir_data_fields.FieldSpec, Any]]],
+ ir: MessageT,
+ field_func: Callable[[MessageT], list[Tuple[ir_data_fields.FieldSpec, Any]]],
) -> MutableMapping[str, Any]:
assert ir is not None
values: MutableMapping[str, Any] = {}
@@ -156,7 +156,7 @@
return lambda val: IrDataSerializer._enum_type_converter(enum_cls, val)
@staticmethod
- def _from_dict(data_cls: type[T], data):
+ def _from_dict(data_cls: type[MessageT], data):
class_fields: MutableMapping[str, Any] = {}
for name, spec in ir_data_fields.field_specs(data_cls).items():
if (value := data.get(name)) is not None:
@@ -185,19 +185,19 @@
return data_cls(**class_fields)
@staticmethod
- def from_dict(data_cls: type[T], data):
+ def from_dict(data_cls: type[MessageT], data):
"""Creates a new IR data instance from a serialized dict"""
return IrDataSerializer._from_dict(data_cls, data)
-class _IrDataSequenceBuilder(MutableSequence[T]):
+class _IrDataSequenceBuilder(MutableSequence[MessageT]):
"""Wrapper for a list of IR elements
Simply wraps the returned values during indexed access and iteration with
IrDataBuilders.
"""
- def __init__(self, target: MutableSequence[T]):
+ def __init__(self, target: MutableSequence[MessageT]):
self._target = target
def __delitem__(self, key):
@@ -233,12 +233,12 @@
self._target.extend(values)
-class _IrDataBuilder(Generic[T]):
+class _IrDataBuilder(Generic[MessageT]):
"""Wrapper for an IR element"""
- def __init__(self, ir: T) -> None:
+ def __init__(self, ir: MessageT) -> None:
assert ir is not None
- self.ir: T = ir
+ self.ir: MessageT = ir
def __setattr__(self, __name: str, __value: Any) -> None:
if __name == "ir":
@@ -246,7 +246,7 @@
object.__setattr__(self, __name, __value)
else:
# Passthrough to the proxy object
- ir: T = object.__getattribute__(self, "ir")
+ ir: MessageT = object.__getattribute__(self, "ir")
setattr(ir, __name, __value)
def __getattribute__(self, name: str) -> Any:
@@ -263,7 +263,7 @@
return object.__getattribute__(self, name)
# Get our target object by bypassing our getattr hook
- ir: T = object.__getattribute__(self, "ir")
+ ir: MessageT = object.__getattribute__(self, "ir")
if ir is None:
return object.__getattribute__(self, name)
@@ -291,12 +291,12 @@
return obj
- def CopyFrom(self, template: T): # pylint:disable=invalid-name
+ def CopyFrom(self, template: MessageT): # pylint:disable=invalid-name
"""Updates the fields of this class with values set in the template"""
- update(cast(type[T], self), template)
+ update(cast(type[MessageT], self), template)
-def builder(target: T) -> T:
+def builder(target: MessageT) -> MessageT:
"""Create a wrapper around the target to help build an IR Data structure"""
# Check if the target is already a builder.
if isinstance(target, (_IrDataBuilder, _IrDataSequenceBuilder)):
@@ -308,7 +308,7 @@
# Create a builder and cast it to the target type to expose type hinting for
# the wrapped type.
- return cast(T, _IrDataBuilder(target))
+ return cast(MessageT, _IrDataBuilder(target))
def _field_checker_from_spec(spec: ir_data_fields.FieldSpec):
@@ -320,7 +320,7 @@
return ir_data_fields.build_default(spec)
-def _field_type(ir_or_spec: T | ir_data_fields.FieldSpec) -> type:
+def _field_type(ir_or_spec: MessageT | ir_data_fields.FieldSpec) -> type:
if isinstance(ir_or_spec, ir_data_fields.FieldSpec):
return ir_or_spec.data_type
return type(ir_or_spec)
@@ -329,7 +329,7 @@
class _ReadOnlyFieldChecker:
"""Class used the chain calls to fields that aren't set"""
- def __init__(self, ir_or_spec: T | ir_data_fields.FieldSpec) -> None:
+ def __init__(self, ir_or_spec: MessageT | ir_data_fields.FieldSpec) -> None:
self.ir_or_spec = ir_or_spec
def __setattr__(self, name: str, value: Any) -> None:
@@ -377,7 +377,7 @@
return not self == other
-def reader(obj: T | _ReadOnlyFieldChecker) -> T:
+def reader(obj: MessageT | _ReadOnlyFieldChecker) -> MessageT:
"""Builds a read-only wrapper that can be used to check chains of possibly
unset fields.
@@ -406,11 +406,11 @@
obj = _ReadOnlyFieldChecker(obj)
# Cast it back to the original type.
- return cast(T, obj)
+ return cast(MessageT, obj)
def _extract_ir(
- ir_or_wrapper: T | _ReadOnlyFieldChecker | _IrDataBuilder | None,
+ ir_or_wrapper: MessageT | _ReadOnlyFieldChecker | _IrDataBuilder | None,
) -> ir_data_fields.IrDataclassInstance | None:
if isinstance(ir_or_wrapper, _ReadOnlyFieldChecker):
ir_or_spec = ir_or_wrapper.ir_or_spec
@@ -424,7 +424,7 @@
def fields_and_values(
- ir_wrapper: T | _ReadOnlyFieldChecker,
+ ir_wrapper: MessageT | _ReadOnlyFieldChecker,
value_filt: Optional[Callable[[Any], bool]] = None,
) -> list[Tuple[ir_data_fields.FieldSpec, Any]]:
"""Retrieves the fields and their values for a given IR data class.
@@ -439,7 +439,7 @@
return ir_data_fields.fields_and_values(ir, value_filt)
-def get_set_fields(ir: T):
+def get_set_fields(ir: MessageT):
"""Retrieves the field spec and value of fields that are set in the given IR data class.
A value is considered "set" if it is not None.
@@ -447,15 +447,15 @@
return fields_and_values(ir, lambda v: v is not None)
-def copy(ir_wrapper: T | None) -> T | None:
+def copy(ir_wrapper: MessageT | None) -> MessageT | None:
"""Creates a copy of the given IR data class"""
if (ir := _extract_ir(ir_wrapper)) is None:
return None
ir_copy = ir_data_fields.copy(ir)
- return cast(T, ir_copy)
+ return cast(MessageT, ir_copy)
-def update(ir: T, template: T):
+def update(ir: MessageT, template: MessageT):
"""Updates `ir`s fields with all set fields in the template."""
if not (template_ir := _extract_ir(template)):
return