pw_bloat: Generator for easy access to DataSourceMap structure

Change-Id: Iad4794cd7af341c44c9ef320e4999cbf6779c53c
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/103921
Commit-Queue: Brandon Vu <brandonvu@google.com>
Reviewed-by: Alexei Frolov <frolv@google.com>
Pigweed-Auto-Submit: Brandon Vu <brandonvu@google.com>
diff --git a/pw_bloat/py/label_test.py b/pw_bloat/py/label_test.py
index 33ff724..b96d9d1 100644
--- a/pw_bloat/py/label_test.py
+++ b/pw_bloat/py/label_test.py
@@ -17,7 +17,13 @@
 import unittest
 import os
 
-from pw_bloat.label import from_bloaty_csv, DataSourceMap
+from pw_bloat.label import from_bloaty_csv, DataSourceMap, Label
+
+LIST_LABELS = [
+    Label(name='main()', size=30, parents=tuple(['FLASH', '.code'])),
+    Label(name='foo()', size=100, parents=tuple(['RAM', '.heap'])),
+    Label(name='bar()', size=220, parents=tuple(['RAM', '.heap']))
+]
 
 
 def get_test_map():
@@ -34,20 +40,101 @@
 class LabelStructTest(unittest.TestCase):
     """Testing class for the label structs."""
     def test_data_source_total_size(self):
-        ds_map = DataSourceMap(["a", "b", "c"])
+        ds_map = DataSourceMap(['a', 'b', 'c'])
         self.assertEqual(ds_map.get_total_size(), 0)
 
     def test_data_source_single_insert_total_size(self):
-        ds_map = DataSourceMap(["a", "b", "c"])
-        ds_map.insert_label_hierachy(["FLASH", ".code", "main()"], 30)
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        ds_map.insert_label_hierachy(['FLASH', '.code', 'main()'], 30)
         self.assertEqual(ds_map.get_total_size(), 30)
 
     def test_data_source_multiple_insert_total_size(self):
-        ds_map = DataSourceMap(["a", "b", "c"])
-        ds_map.insert_label_hierachy(["FLASH", ".code", "main()"], 30)
-        ds_map.insert_label_hierachy(["RAM", ".code", "foo()"], 100)
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        ds_map.insert_label_hierachy(['FLASH', '.code', 'main()'], 30)
+        ds_map.insert_label_hierachy(['RAM', '.code', 'foo()'], 100)
         self.assertEqual(ds_map.get_total_size(), 130)
 
+    def test_parsing_generator_three_datasource_names(self):
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        for label in LIST_LABELS:
+            ds_map.insert_label_hierachy(
+                [label.parents[0], label.parents[1], label.name], label.size)
+        list_labels_three = [*LIST_LABELS, Label(name='total', size=350)]
+        for label_hiearchy in ds_map.labels():
+            self.assertIn(label_hiearchy, list_labels_three)
+        self.assertEqual(ds_map.get_total_size(), 350)
+
+    def test_parsing_generator_two_datasource_names(self):
+        ds_map = DataSourceMap(['a', 'b'])
+        ds_label_list = [
+            Label(name='main()', size=30, parents=tuple(['FLASH'])),
+            Label(name='foo()', size=100, parents=tuple(['RAM'])),
+            Label(name='bar()', size=220, parents=tuple(['RAM']))
+        ]
+        for label in ds_label_list:
+            ds_map.insert_label_hierachy([label.parents[0], label.name],
+                                         label.size)
+        list_labels_two = [*ds_label_list, Label(name='total', size=350)]
+        for label_hiearchy in ds_map.labels():
+            self.assertIn(label_hiearchy, list_labels_two)
+        self.assertEqual(ds_map.get_total_size(), 350)
+
+    def test_parsing_generator_specified_datasource_1(self):
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        for label in LIST_LABELS:
+            ds_map.insert_label_hierachy(
+                [label.parents[0], label.parents[1], label.name], label.size)
+        list_labels_ds_b = [
+            Label(name='.code', size=30, parents=tuple(['FLASH'])),
+            Label(name='.heap', size=320, parents=tuple(['RAM']))
+        ]
+        list_labels_ds_b += [Label(name='total', size=350)]
+        for label_hiearchy in ds_map.labels(1):
+            self.assertIn(label_hiearchy, list_labels_ds_b)
+        self.assertEqual(ds_map.get_total_size(), 350)
+
+    def test_parsing_generator_specified_datasource_str_2(self):
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        for label in LIST_LABELS:
+            ds_map.insert_label_hierachy(
+                [label.parents[0], label.parents[1], label.name], label.size)
+        list_labels_ds_a = [
+            Label(name='FLASH', size=30, parents=tuple([])),
+            Label(name='RAM', size=320, parents=tuple([]))
+        ]
+        list_labels_ds_a += [Label(name='total', size=350)]
+        for label_hiearchy in ds_map.labels(0):
+            self.assertIn(label_hiearchy, list_labels_ds_a)
+        self.assertEqual(ds_map.get_total_size(), 350)
+
+    def test_parsing_generator_specified_datasource_int(self):
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        for label in LIST_LABELS:
+            ds_map.insert_label_hierachy(
+                [label.parents[0], label.parents[1], label.name], label.size)
+        list_labels_ds_a = [
+            Label(name='FLASH', size=30, parents=tuple([])),
+            Label(name='RAM', size=320, parents=tuple([]))
+        ]
+        list_labels_ds_a += [Label(name='total', size=350)]
+        for label_hiearchy in ds_map.labels(0):
+            self.assertIn(label_hiearchy, list_labels_ds_a)
+        self.assertEqual(ds_map.get_total_size(), 350)
+
+    def test_parsing_generator_specified_datasource_int_2(self):
+        ds_map = DataSourceMap(['a', 'b', 'c'])
+        for label in LIST_LABELS:
+            ds_map.insert_label_hierachy(
+                [label.parents[0], label.parents[1], label.name], label.size)
+        list_labels_ds_b = [
+            Label(name='.code', size=30, parents=tuple(['FLASH'])),
+            Label(name='.heap', size=320, parents=tuple(['RAM']))
+        ]
+        list_labels_ds_b += [Label(name='total', size=350)]
+        for label_hiearchy in ds_map.labels(1):
+            self.assertIn(label_hiearchy, list_labels_ds_b)
+        self.assertEqual(ds_map.get_total_size(), 350)
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/pw_bloat/py/pw_bloat/label.py b/pw_bloat/py/pw_bloat/label.py
index ba8aed1..08d475f 100644
--- a/pw_bloat/py/pw_bloat/label.py
+++ b/pw_bloat/py/pw_bloat/label.py
@@ -12,28 +12,35 @@
 # License for the specific language governing permissions and limitations under
 # the License.
 """
-LabelMap and Label moduled defines the data structure to hold
-size reports from Bloaty.
+The label module defines a class to store and manipulate size reports.
 """
 
 from collections import defaultdict
-from typing import Iterable, Dict, Tuple, List
+from dataclasses import dataclass
+from typing import Iterable, Dict, Sequence, Tuple, List, Generator, Optional
 
 import csv
 
 
+@dataclass
+class Label:
+    """Return type of DataSourceMap generator."""
+    name: str
+    size: int
+    parents: Tuple[str, ...] = ()
+
+
 class _LabelMap:
-    """Private module to store a parent label and all of its
-    child labels with its corresponding size in a nested dictionary."""
+    """Private module to hold parent and child labels with their size."""
     _label_map: Dict[str, Dict[str, int]]
 
     def __init__(self):
         self._label_map = defaultdict(lambda: defaultdict(int))
 
-    def remove(self, parent_label: str, label: str = None) -> None:
+    def remove(self, parent_label: str, child_label: str = None) -> None:
         """Delete entire parent label or the child label."""
-        if label:
-            del self._label_map[parent_label][label]
+        if child_label:
+            del self._label_map[parent_label][child_label]
         else:
             del self._label_map[parent_label]
 
@@ -41,10 +48,13 @@
         """Subtract the current LabelMap to the base."""
 
     def __getitem__(self, parent_label: str) -> Dict[str, int]:
-        """Allow indexing of a LabelMap using '[]' operators
-        by specifying a label to access."""
+        """Indexing LabelMap using '[]' operators by specifying a label."""
         return self._label_map[parent_label]
 
+    def map_generator(self) -> Generator:
+        for parent_label, label_dict in self._label_map.items():
+            yield parent_label, label_dict
+
 
 class _DataSource:
     """Private module to store a data source name with a _LabelMap."""
@@ -55,17 +65,26 @@
     def get_name(self) -> str:
         return self._name
 
-    def add_label_size(self, parent_label: str, label: str, size: int) -> None:
-        self._ds_label_map[parent_label][label] += size
+    def add_label_size(self, parent_label: str, child_label: str,
+                       size: int) -> None:
+        self._ds_label_map[parent_label][child_label] += size
 
     def __getitem__(self, parent_label: str) -> Dict[str, int]:
         return self._ds_label_map[parent_label]
 
+    def label_map_generator(self) -> Generator:
+        for parent_label, label_dict in self._ds_label_map.map_generator():
+            yield parent_label, label_dict
+
 
 class DataSourceMap:
-    """Module with an array of DataSources to organize a hierachy
-    of labels and their sizes. Includes a capacity array to hold regex
-    patterns for applying capacities to matching labels."""
+    """Module to store an array of DataSources and capacities.
+
+    An organize way to store a hierachy of labels and their sizes.
+    Includes a capacity array to hold regex patterns for applying
+    capacities to matching label names.
+
+    """
     def __init__(self, data_sources_names: Iterable[str]):
         self._data_sources = list(
             _DataSource(name) for name in ['base', *data_sources_names])
@@ -92,12 +111,51 @@
     def get_total_size(self) -> int:
         return self._data_sources[0]['__base__']['total']
 
+    def get_ds_names(self) -> Tuple[str, ...]:
+        """List of DataSource names for easy indexing and reference."""
+        return tuple(data_source.get_name()
+                     for data_source in self._data_sources[1:])
+
+    def labels(self, ds_index: Optional[int] = None) -> Iterable[Label]:
+        """Generator that yields a Label depending on specified data source.
+
+        Args:
+            ds_index: Integer index of target data source.
+
+        Returns:
+            Iterable Label objects.
+        """
+        ds_index = len(
+            self._data_sources) if ds_index is None else ds_index + 2
+        yield from self._per_data_source_generator(
+            tuple(), self._data_sources[1:ds_index])
+
+    def _per_data_source_generator(
+            self, parent_labels: Tuple[str, ...],
+            data_sources: Sequence[_DataSource]) -> Iterable[Label]:
+        """Recursive generator to return Label based off parent labels."""
+        for ds_index, curr_ds in enumerate(data_sources):
+            for parent_label, label_map in curr_ds.label_map_generator():
+                if not parent_labels:
+                    curr_parent = 'total'
+                else:
+                    curr_parent = parent_labels[-1]
+                if parent_label == curr_parent:
+                    for child_label, size in label_map.items():
+                        if len(data_sources) == 1:
+                            yield Label(child_label, size, parent_labels)
+                        else:
+                            yield from self._per_data_source_generator(
+                                (*parent_labels, child_label),
+                                data_sources[ds_index + 1:])
+
 
 def from_bloaty_csv(raw_csv: Iterable[str]) -> DataSourceMap:
     """Read in Bloaty CSV output and store in DataSourceMap."""
     reader = csv.reader(raw_csv)
     top_row = next(reader)
-    ds_map = DataSourceMap(top_row[:-2])
+    ds_map_csv = DataSourceMap(top_row[:-2])
+    vmsize_index = top_row.index('vmsize')
     for row in reader:
-        ds_map.insert_label_hierachy(row[:-2], int(row[-2]))
-    return ds_map
+        ds_map_csv.insert_label_hierachy(row[:-2], int(row[vmsize_index]))
+    return ds_map_csv