Cleanup ir_data_fields
diff --git a/compiler/util/ir_data_fields.py b/compiler/util/ir_data_fields.py
index 8db01f2..fe9ef75 100644
--- a/compiler/util/ir_data_fields.py
+++ b/compiler/util/ir_data_fields.py
@@ -107,7 +107,7 @@
 
 
 class TemporaryCopyValuesList(NamedTuple):
-  """Class used to temporarily hold a CopyValueList while copying and
+  """Class used to temporarily hold a CopyValuesList while copying and
 
   constructing an IR dataclass.
   """
@@ -197,15 +197,6 @@
       if isinstance(type, v.__class__) and _is_ir_dataclass(v)
   )
 
-
-def _all_field_specs(mod):
-  specs = {
-      ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
-      for ir_class in all_ir_classes(mod)
-  }
-  return specs
-
-
 class IrDataclassSpecs:
   """Maintains a cache of all IR dataclass specs."""
 
@@ -214,8 +205,10 @@
   @classmethod
   def get_mod_specs(cls, mod):
     """Gets the IR dataclass specs for the given module."""
-    mod_specs = _all_field_specs(mod)
-    return mod_specs
+    return {
+      ir_class: FilteredIrFieldSpecs(_field_specs(ir_class))
+      for ir_class in all_ir_classes(mod)
+    }
 
   @classmethod
   def get_specs(cls, data_class):
@@ -255,17 +248,19 @@
     # Get the wrapped types if there are any
     args = get_args(type_hint)
     if origin is not None:
-      # Extract the type
+      # Extract the type.
+      type_hint = args[0]
+
+      # Underneath the hood `typing.Optional` is just a `Union[T, None]` so we
+      # have to check if it's a `Union` instead of just using `Optional`.
       if origin == Union:
-        # Extract Optionals -> Union[T, None]
+        # Make sure this is an `Optional` and not another `Union` type.
         assert len(args) == 2 and args[1] == type(None)
         container_type = FieldContainer.OPTIONAL
-        type_hint = args[0]
-        origin = get_origin(type_hint)
-
-      if origin == list:
+      elif origin == list:
         container_type = FieldContainer.LIST
-        type_hint = get_args(type_hint)[0]
+      else:
+        raise TypeError(f"Field has invalid container type: {origin}")
 
     # Resolve any forward references.
     if isinstance(type_hint, str):