Consolidate `cache_data_fields`
diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py
index 9e19da4..55f12a7 100644
--- a/compiler/util/ir_data.py
+++ b/compiler/util/ir_data.py
@@ -75,17 +75,6 @@
return None
-def _cache_message_specs():
- # This needs to be done after the dataclass decorators run and create the wrapped classes.
- for data_class in ir_data_fields.all_ir_classes(
- sys.modules[Message.__module__]
- ):
- if data_class is not Message:
- data_class.field_specs = ir_data_fields.IrDataclassSpecs.get_specs(
- data_class
- )
-
-
################################################################################
# From here to the end of the file are actual structure definitions.
@@ -875,4 +864,4 @@
# Post-process the dataclasses to add cached fields.
-_cache_message_specs()
+ir_data_fields.cache_message_specs(sys.modules[Message.__module__], Message)
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index fe9ef75..f7991c0 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -219,6 +219,20 @@
return cls.spec_cache[data_class]
+def cache_message_specs(mod, cls):
+ """Adds a cached `field_specs` attribute to IR dataclasses in `mod`
+ excluding the given base `cls`.
+
+ This needs to be done after the dataclass decorators run and create the
+ wrapped classes.
+ """
+ for data_class in all_ir_classes(mod):
+ if data_class is not cls:
+ data_class.field_specs = IrDataclassSpecs.get_specs(
+ data_class
+ )
+
+
def _field_specs(cls: type[T]) -> Mapping[str, FieldSpec]:
"""Gets the IR data field names and types for the given IR data class"""
# Get the dataclass fields
diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py
index 9e27536..fa99943 100644
--- a/compiler/util/ir_data_fields_test.py
+++ b/compiler/util/ir_data_fields_test.py
@@ -85,18 +85,6 @@
normal_field: bool = True
-def _cache_message_specs():
- # This needs to be done after the dataclass decorators run and create the
- # wrapped classes.
- for data_class in ir_data_fields.all_ir_classes(
- sys.modules[OneofFieldTest.__module__]
- ):
- if data_class is not ir_data.Message:
- data_class.field_specs = ir_data_fields.IrDataclassSpecs.get_specs(
- data_class
- )
-
-
class OneOfTest(unittest.TestCase):
"""Tests for the the various oneof field helpers"""
@@ -259,7 +247,8 @@
self.assertEqual(oneof_test.normal_field, False)
-_cache_message_specs()
+ir_data_fields.cache_message_specs(
+ sys.modules[OneofFieldTest.__module__], ir_data.Message)
if __name__ == "__main__":
unittest.main()
diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py
index 951444a..2ea873e 100644
--- a/compiler/util/ir_data_utils_test.py
+++ b/compiler/util/ir_data_utils_test.py
@@ -644,19 +644,8 @@
self.assertRaises(AttributeError, set_field)
-def _cache_message_specs():
- # This needs to be done after the dataclass decorators run and create the
- # wrapped classes.
- for data_class in ir_data_fields.all_ir_classes(
- sys.modules[ReadOnlyFieldCheckerTest.__module__]
- ):
- if data_class is not ir_data.Message:
- data_class.field_specs = ir_data_fields.IrDataclassSpecs.get_specs(
- data_class
- )
-
-
-_cache_message_specs()
+ir_data_fields.cache_message_specs(
+ sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message)
if __name__ == "__main__":
unittest.main()