Merge pull request #142 from EricRahm/add_ir_data_serializer

Split out an `IrDataSerializer` class
diff --git a/compiler/back_end/cpp/emboss_codegen_cpp.py b/compiler/back_end/cpp/emboss_codegen_cpp.py
index 0a70f41..6da9f7d 100644
--- a/compiler/back_end/cpp/emboss_codegen_cpp.py
+++ b/compiler/back_end/cpp/emboss_codegen_cpp.py
@@ -27,6 +27,7 @@
 from compiler.back_end.cpp import header_generator
 from compiler.util import error
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 
 
 def _parse_command_line(argv):
@@ -82,9 +83,9 @@
 def main(flags):
   if flags.input_file:
     with open(flags.input_file) as f:
-      ir = ir_data.EmbossIr.from_json(f.read())
+      ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, f.read())
   else:
-    ir = ir_data.EmbossIr.from_json(sys.stdin.read())
+    ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, sys.stdin.read())
   config = header_generator.Config(include_enum_traits=flags.cc_enum_traits)
   header, errors = generate_headers_and_log_errors(ir, flags.color_output, config)
   if errors:
diff --git a/compiler/front_end/emboss_front_end.py b/compiler/front_end/emboss_front_end.py
index 6232a40..1e30ded 100644
--- a/compiler/front_end/emboss_front_end.py
+++ b/compiler/front_end/emboss_front_end.py
@@ -30,6 +30,7 @@
 from compiler.front_end import glue
 from compiler.front_end import module_ir
 from compiler.util import error
+from compiler.util import ir_data_utils
 
 
 def _parse_command_line(argv):
@@ -178,10 +179,10 @@
     print(glue.format_production_set(
         set(module_ir.PRODUCTIONS) - main_module_debug_info.used_productions))
   if flags.output_ir_to_stdout:
-    print(ir.to_json())
+    print(ir_data_utils.IrDataSerializer(ir).to_json())
   if flags.output_file:
     with open(flags.output_file, "w") as f:
-      f.write(ir.to_json())
+      f.write(ir_data_utils.IrDataSerializer(ir).to_json())
   return 0
 
 
diff --git a/compiler/front_end/glue.py b/compiler/front_end/glue.py
index a1b0706..7724da9 100644
--- a/compiler/front_end/glue.py
+++ b/compiler/front_end/glue.py
@@ -34,6 +34,7 @@
 from compiler.front_end import write_inference
 from compiler.util import error
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import parser_types
 from compiler.util import resources
 
@@ -111,7 +112,7 @@
 
   def format_module_ir(self):
     """Renders self.ir in a human-readable format."""
-    return self.ir.to_json(indent=2)
+    return ir_data_utils.IrDataSerializer(self.ir).to_json(indent=2)
 
 
 def format_production_set(productions):
diff --git a/compiler/front_end/glue_test.py b/compiler/front_end/glue_test.py
index a2b61ad..10613d7 100644
--- a/compiler/front_end/glue_test.py
+++ b/compiler/front_end/glue_test.py
@@ -20,6 +20,7 @@
 from compiler.front_end import glue
 from compiler.util import error
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import parser_types
 from compiler.util import test_util
 
@@ -33,7 +34,7 @@
     _ROOT_PACKAGE, _SPAN_SE_LOG_FILE_PATH).decode(encoding="UTF-8")
 _SPAN_SE_LOG_FILE_READER = test_util.dict_file_reader(
     {_SPAN_SE_LOG_FILE_PATH: _SPAN_SE_LOG_FILE_EMB})
-_SPAN_SE_LOG_FILE_IR = ir_data.Module.from_json(
+_SPAN_SE_LOG_FILE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.Module,
     pkgutil.get_data(
         _ROOT_PACKAGE,
         _GOLDEN_PATH + "span_se_log_file_status.ir.txt"
@@ -155,7 +156,7 @@
     self.assertEqual(_SPAN_SE_LOG_FILE_PARSE_TREE_TEXT.strip(),
                      debug_info.format_parse_tree().strip())
     self.assertEqual(_SPAN_SE_LOG_FILE_IR, debug_info.ir)
-    self.assertEqual(_SPAN_SE_LOG_FILE_IR.to_json(indent=2),
+    self.assertEqual(ir_data_utils.IrDataSerializer(_SPAN_SE_LOG_FILE_IR).to_json(indent=2),
                      debug_info.format_module_ir())
 
   def test_parse_emboss_file(self):
diff --git a/compiler/front_end/module_ir_test.py b/compiler/front_end/module_ir_test.py
index 1f4233d..57d5f4c 100644
--- a/compiler/front_end/module_ir_test.py
+++ b/compiler/front_end/module_ir_test.py
@@ -24,6 +24,7 @@
 from compiler.front_end import parser
 from compiler.front_end import tokenizer
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import test_util
 
 _TESTDATA_PATH = "testdata.golden"
@@ -31,7 +32,7 @@
         _TESTDATA_PATH, "span_se_log_file_status.emb").decode(encoding="UTF-8")
 _MINIMAL_SAMPLE = parser.parse_module(
     tokenizer.tokenize(_MINIMAL_SOURCE, "")[0]).parse_tree
-_MINIMAL_SAMPLE_IR = ir_data.Module.from_json(
+_MINIMAL_SAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.Module,
     pkgutil.get_data(_TESTDATA_PATH, "span_se_log_file_status.ir.txt").decode(
         encoding="UTF-8")
 )
@@ -3978,7 +3979,7 @@
     name, emb, ir_text = case.split("---")
     name = name.strip()
     try:
-      ir = ir_data.Module.from_json(ir_text)
+      ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Module, ir_text)
     except Exception:
       print(name)
       raise
@@ -4152,10 +4153,11 @@
     def test_case(self):
       ir = module_ir.build_ir(test.parse_tree)
       is_superset, error_message = test_util.proto_is_superset(ir, test.ir)
+
       self.assertTrue(
           is_superset,
-          error_message + "\n" + ir.to_json(indent=2) + "\n" +
-          test.ir.to_json(indent=2))
+          error_message + "\n" + ir_data_utils.IrDataSerializer(ir).to_json(indent=2) + "\n" +
+          ir_data_utils.IrDataSerializer(test.ir).to_json(indent=2))
 
     return test_case
 
diff --git a/compiler/front_end/type_check_test.py b/compiler/front_end/type_check_test.py
index 6906738..d308fed 100644
--- a/compiler/front_end/type_check_test.py
+++ b/compiler/front_end/type_check_test.py
@@ -18,6 +18,7 @@
 from compiler.front_end import glue
 from compiler.front_end import type_check
 from compiler.util import error
+from compiler.util import ir_data_utils
 from compiler.util import test_util
 
 
@@ -44,7 +45,7 @@
                        "  0 [+1]     UInt      x\n"
                        "  1 [+true]  UInt:8[]  y\n")
     self.assertEqual([], error.filter_errors(type_check.annotate_types(ir)),
-                     ir.to_json(indent=2))
+                     ir_data_utils.IrDataSerializer(ir).to_json(indent=2))
     expression = ir.module[0].type[0].structure.field[1].location.size
     self.assertEqual(expression.type.WhichOneof("type"), "boolean")
 
diff --git a/compiler/util/BUILD b/compiler/util/BUILD
index bbc2ec0..5946dcb 100644
--- a/compiler/util/BUILD
+++ b/compiler/util/BUILD
@@ -25,6 +25,7 @@
     name = "ir_data",
     srcs = [
         "ir_data.py",
+        "ir_data_utils.py",
     ],
 )
 
diff --git a/compiler/util/ir_data_utils.py b/compiler/util/ir_data_utils.py
new file mode 100644
index 0000000..63b55b7
--- /dev/null
+++ b/compiler/util/ir_data_utils.py
@@ -0,0 +1,61 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Utilities for serializing and deserializing IR data objects."""
+
+from compiler.util import ir_data
+
+
+class IrDataSerializer:
+  """Converts IR data objects to and from JSON strings."""
+
+  def __init__(self, ir: ir_data.Message):
+    """Wraps `ir` for serialization.
+
+    Args:
+      ir: The IR data object to serialize.
+
+    Raises:
+      TypeError: If `ir` is None.
+    """
+    # An explicit check, unlike `assert`, is not stripped by `python -O`.
+    if ir is None:
+      raise TypeError("IrDataSerializer requires an IR object, got None")
+    self.ir = ir
+
+  def to_json(self, *args, **kwargs) -> str:
+    """Converts the wrapped IR data object to a JSON string.
+
+    Args:
+      *args: Positional arguments forwarded to `self.ir.to_json`.
+      **kwargs: Keyword arguments forwarded to `self.ir.to_json`, such as
+        `indent`.
+
+    Returns:
+      The JSON text produced by `self.ir.to_json`.
+    """
+    return self.ir.to_json(*args, **kwargs)
+
+  @staticmethod
+  def from_json(data_cls: type[ir_data.Message], data: str):
+    """Constructs an IR data object of type `data_cls` from a JSON string.
+
+    Args:
+      data_cls: The IR data class to instantiate, e.g. `ir_data.EmbossIr`.
+      data: The JSON text to parse.
+
+    Returns:
+      The `data_cls` instance returned by `data_cls.from_json(data)`.
+    """
+    return data_cls.from_json(data)
diff --git a/compiler/util/ir_util_test.py b/compiler/util/ir_util_test.py
index 1afed9c..b92ffb9 100644
--- a/compiler/util/ir_util_test.py
+++ b/compiler/util/ir_util_test.py
@@ -17,6 +17,7 @@
 import unittest
 from compiler.util import expression_parser
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import ir_util
 
 
@@ -410,7 +411,7 @@
                       "bob")
 
   def test_find_object(self):
-    ir = ir_data.EmbossIr.from_json(
+    ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr,
         """{
           "module": [
             {
@@ -564,7 +565,7 @@
                                                 object_path=["Foo", "Bar"]))))
 
   def test_get_base_type(self):
-    array_type_ir = ir_data.Type.from_json(
+    array_type_ir = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "array_type": {
             "element_count": { "constant": { "value": "20" } },
@@ -590,7 +591,7 @@
     self.assertEqual(base_type_ir, ir_util.get_base_type(base_type_ir))
 
   def test_size_of_type_in_bits(self):
-    ir = ir_data.EmbossIr.from_json(
+    ir = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr,
         """{
           "module": [{
             "type": [{
@@ -638,7 +639,7 @@
           }]
         }""")
 
-    fixed_size_type = ir_data.Type.from_json(
+    fixed_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "atomic_type": {
             "reference": {
@@ -648,7 +649,7 @@
         }""")
     self.assertEqual(8, ir_util.fixed_size_of_type_in_bits(fixed_size_type, ir))
 
-    explicit_size_type = ir_data.Type.from_json(
+    explicit_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "atomic_type": {
             "reference": {
@@ -665,7 +666,7 @@
     self.assertEqual(32,
                      ir_util.fixed_size_of_type_in_bits(explicit_size_type, ir))
 
-    fixed_size_array = ir_data.Type.from_json(
+    fixed_size_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "array_type": {
             "base_type": {
@@ -686,7 +687,7 @@
     self.assertEqual(40,
                      ir_util.fixed_size_of_type_in_bits(fixed_size_array, ir))
 
-    fixed_size_2d_array = ir_data.Type.from_json(
+    fixed_size_2d_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "array_type": {
             "base_type": {
@@ -720,7 +721,7 @@
     self.assertEqual(
         80, ir_util.fixed_size_of_type_in_bits(fixed_size_2d_array, ir))
 
-    automatic_size_array = ir_data.Type.from_json(
+    automatic_size_array = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "array_type": {
             "base_type": {
@@ -749,7 +750,7 @@
     self.assertIsNone(
         ir_util.fixed_size_of_type_in_bits(automatic_size_array, ir))
 
-    variable_size_type = ir_data.Type.from_json(
+    variable_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "atomic_type": {
             "reference": {
@@ -760,7 +761,7 @@
     self.assertIsNone(
         ir_util.fixed_size_of_type_in_bits(variable_size_type, ir))
 
-    no_size_type = ir_data.Type.from_json(
+    no_size_type = ir_data_utils.IrDataSerializer.from_json(ir_data.Type,
         """{
           "atomic_type": {
             "reference": {
diff --git a/compiler/util/traverse_ir_test.py b/compiler/util/traverse_ir_test.py
index 64da8f6..ff54d63 100644
--- a/compiler/util/traverse_ir_test.py
+++ b/compiler/util/traverse_ir_test.py
@@ -19,9 +19,10 @@
 import unittest
 
 from compiler.util import ir_data
+from compiler.util import ir_data_utils
 from compiler.util import traverse_ir
 
-_EXAMPLE_IR = ir_data.EmbossIr.from_json("""{
+_EXAMPLE_IR = ir_data_utils.IrDataSerializer.from_json(ir_data.EmbossIr, """{
 "module": [
   {
     "type": [