Consolidate `cache_data_fields`
diff --git a/compiler/util/ir_data.py b/compiler/util/ir_data.py
index 9e19da4..55f12a7 100644
--- a/compiler/util/ir_data.py
+++ b/compiler/util/ir_data.py
@@ -75,17 +75,6 @@
     return None
 
 
-def _cache_message_specs():
-  # This needs to be done after the dataclass decorators run and create the wrapped classes.
-  for data_class in ir_data_fields.all_ir_classes(
-      sys.modules[Message.__module__]
-  ):
-    if data_class is not Message:
-      data_class.field_specs = ir_data_fields.IrDataclassSpecs.get_specs(
-          data_class
-      )
-
-
 ################################################################################
 # From here to the end of the file are actual structure definitions.
 
@@ -875,4 +864,4 @@
 
 
 # Post-process the dataclasses to add cached fields.
-_cache_message_specs()
+ir_data_fields.cache_message_specs(sys.modules[Message.__module__], Message)
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index fe9ef75..f7991c0 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -219,6 +219,20 @@
     return cls.spec_cache[data_class]
 
 
+def cache_message_specs(mod, cls):
+  """Adds a cached `field_specs` attribute to IR dataclasses in `mod`
+  excluding the given base `cls`.
+
+  This needs to be done after the dataclass decorators run and create the
+  wrapped classes.
+  """
+  for data_class in all_ir_classes(mod):
+    if data_class is not cls:
+      data_class.field_specs = IrDataclassSpecs.get_specs(
+          data_class
+      )
+
+
 def _field_specs(cls: type[T]) -> Mapping[str, FieldSpec]:
   """Gets the IR data field names and types for the given IR data class"""
   # Get the dataclass fields
diff --git a/compiler/util/ir_data_fields_test.py b/compiler/util/ir_data_fields_test.py
index 9e27536..fa99943 100644
--- a/compiler/util/ir_data_fields_test.py
+++ b/compiler/util/ir_data_fields_test.py
@@ -85,18 +85,6 @@
   normal_field: bool = True
 
 
-def _cache_message_specs():
-  # This needs to be done after the dataclass decorators run and create the
-  # wrapped classes.
-  for data_class in ir_data_fields.all_ir_classes(
-      sys.modules[OneofFieldTest.__module__]
-  ):
-    if data_class is not ir_data.Message:
-      data_class.field_specs = ir_data_fields.IrDataclassSpecs.get_specs(
-          data_class
-      )
-
-
 class OneOfTest(unittest.TestCase):
   """Tests for the the various oneof field helpers"""
 
@@ -259,7 +247,8 @@
     self.assertEqual(oneof_test.normal_field, False)
 
 
-_cache_message_specs()
+ir_data_fields.cache_message_specs(
+  sys.modules[OneofFieldTest.__module__], ir_data.Message)
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/compiler/util/ir_data_utils_test.py b/compiler/util/ir_data_utils_test.py
index 951444a..2ea873e 100644
--- a/compiler/util/ir_data_utils_test.py
+++ b/compiler/util/ir_data_utils_test.py
@@ -644,19 +644,8 @@
     self.assertRaises(AttributeError, set_field)
 
 
-def _cache_message_specs():
-  # This needs to be done after the dataclass decorators run and create the
-  # wrapped classes.
-  for data_class in ir_data_fields.all_ir_classes(
-      sys.modules[ReadOnlyFieldCheckerTest.__module__]
-  ):
-    if data_class is not ir_data.Message:
-      data_class.field_specs = ir_data_fields.IrDataclassSpecs.get_specs(
-          data_class
-      )
-
-
-_cache_message_specs()
+ir_data_fields.cache_message_specs(
+  sys.modules[ReadOnlyFieldCheckerTest.__module__], ir_data.Message)
 
 if __name__ == "__main__":
   unittest.main()