Hook up PSAMacroCollector to PSAMacroEnumerator

Make it possible to enumerate the key types, algorithms, etc.
collected by PSAMacroCollector.

This commit ensures that all fields of PSAMacroEnumerator are filled
by code inspection. Testing of the result may reveal more work to be
done in later commits.

Signed-off-by: Gilles Peskine <Gilles.Peskine@arm.com>
diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py
index d0d1f3f..ff07ecd 100755
--- a/scripts/generate_psa_constants.py
+++ b/scripts/generate_psa_constants.py
@@ -304,7 +304,7 @@
 
     def _make_key_usage_code(self):
         return '\n'.join([self._make_bit_test('usage', bit)
-                          for bit in sorted(self.key_usages)])
+                          for bit in sorted(self.key_usage_flags)])
 
     def write_file(self, output_file):
         """Generate the pretty-printer function code from the gathered
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index c9e6ec3..a2192ba 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -126,7 +126,7 @@
         return itertools.chain(*map(self.distribute_arguments, names))
 
 
-class PSAMacroCollector:
+class PSAMacroCollector(PSAMacroEnumerator):
     """Collect PSA crypto macro definitions from C header files.
     """
 
@@ -138,18 +138,11 @@
         * include_intermediate: if true, include intermediate macros such as
           PSA_XXX_BASE that do not designate semantic values.
         """
+        super().__init__()
         self.include_intermediate = include_intermediate
-        self.statuses = set() #type: Set[str]
-        self.key_types = set() #type: Set[str]
         self.key_types_from_curve = {} #type: Dict[str, str]
         self.key_types_from_group = {} #type: Dict[str, str]
-        self.ecc_curves = set() #type: Set[str]
-        self.dh_groups = set() #type: Set[str]
-        self.algorithms = set() #type: Set[str]
-        self.hash_algorithms = set() #type: Set[str]
-        self.ka_algorithms = set() #type: Set[str]
         self.algorithms_from_hash = {} #type: Dict[str, str]
-        self.key_usages = set() #type: Set[str]
 
     def is_internal_name(self, name: str) -> bool:
         """Whether this is an internal macro. Internal macros will be skipped."""
@@ -160,6 +153,30 @@
                 return True
         return name.endswith('_FLAG') or name.endswith('_MASK')
 
+    def record_algorithm_subtype(self, name: str, expansion: str) -> None:
+        """Record the subtype of an algorithm constructor.
+
+        Given a ``PSA_ALG_xxx`` macro name and its expansion, if the algorithm
+        is of a subtype that is tracked in its own set, add it to the relevant
+        set.
+        """
+        # This code is very ad hoc and fragile. It should be replaced by
+        # something more robust.
+        if re.match(r'MAC(?:_|\Z)', name):
+            self.mac_algorithms.add(name)
+        elif re.match(r'KDF(?:_|\Z)', name):
+            self.kdf_algorithms.add(name)
+        elif re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
+            self.hash_algorithms.add(name)
+        elif re.search(r'0x03[0-9A-Fa-f]{6}', expansion):
+            self.mac_algorithms.add(name)
+        elif re.search(r'0x05[0-9A-Fa-f]{6}', expansion):
+            self.aead_algorithms.add(name)
+        elif re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
+            self.ka_algorithms.add(name)
+        elif re.search(r'0x08[0-9A-Fa-f]{6}', expansion):
+            self.kdf_algorithms.add(name)
+
     # "#define" followed by a macro name with either no parameters
     # or a single parameter and a non-empty expansion.
     # Grab the macro name in group 1, the parameter name if any in group 2
@@ -180,6 +197,8 @@
             return
         name, parameter, expansion = m.groups()
         expansion = re.sub(r'/\*.*?\*/|//.*', r' ', expansion)
+        if parameter:
+            self.argspecs[name] = [parameter]
         if re.match(self._deprecated_definition_re, expansion):
             # Skip deprecated values, which are assumed to be
             # backward compatibility aliases that share
@@ -207,12 +226,7 @@
                 # Ad hoc skipping of duplicate names for some numerical values
                 return
             self.algorithms.add(name)
-            # Ad hoc detection of hash algorithms
-            if re.search(r'0x020000[0-9A-Fa-f]{2}', expansion):
-                self.hash_algorithms.add(name)
-            # Ad hoc detection of key agreement algorithms
-            if re.search(r'0x09[0-9A-Fa-f]{2}0000', expansion):
-                self.ka_algorithms.add(name)
+            self.record_algorithm_subtype(name, expansion)
         elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
             if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']:
                 # A naming irregularity
@@ -221,7 +235,7 @@
                 tester = name[:8] + 'IS_' + name[8:]
             self.algorithms_from_hash[name] = tester
         elif name.startswith('PSA_KEY_USAGE_') and not parameter:
-            self.key_usages.add(name)
+            self.key_usage_flags.add(name)
         else:
             # Other macro without parameter
             return