Cleanup ir_data_fields
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index 8db01f2..fe9ef75 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -107,7 +107,7 @@
class TemporaryCopyValuesList(NamedTuple):
- """Class used to temporarily hold a CopyValueList while copying and
+ """Class used to temporarily hold a CopyValuesList while copying and
constructing an IR dataclass.
"""
@@ -197,15 +197,6 @@
if isinstance(type, v.__class__) and _is_ir_dataclass(v)
)
-
-def _all_field_specs(mod):
- specs = {
- ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
- for ir_class in all_ir_classes(mod)
- }
- return specs
-
-
class IrDataclassSpecs:
"""Maintains a cache of all IR dataclass specs."""
@@ -214,8 +205,10 @@
@classmethod
def get_mod_specs(cls, mod):
"""Gets the IR dataclass specs for the given module."""
- mod_specs = _all_field_specs(mod)
- return mod_specs
+ return {
+ ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
+ for ir_class in all_ir_classes(mod)
+ }
@classmethod
def get_specs(cls, data_class):
@@ -255,17 +248,19 @@
# Get the wrapped types if there are any
args = get_args(type_hint)
if origin is not None:
- # Extract the type
+ # Extract the type.
+ type_hint = args[0]
+
+ # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we
+ # have to check if it's a `Union` instead of just using `Optional`.
if origin == Union:
- # Extract Optionals -> Union[T, None]
+ # Make sure this is an `Optional` and not another `Union` type.
assert len(args) == 2 and args[1] == type(None)
container_type = FieldContainer.OPTIONAL
- type_hint = args[0]
- origin = get_origin(type_hint)
-
- if origin == list:
+ elif origin == list:
container_type = FieldContainer.LIST
- type_hint = get_args(type_hint)[0]
+ else:
+ raise TypeError(f"Field has invalid container type: {origin}")
# Resolve any forward references.
if isinstance(type_hint, str):