Refactor key generation

Remove the key builder and use iterator instead of lists.

Signed-off-by: gabor-mezei-arm <gabor.mezei@arm.com>
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
old mode 100644
new mode 100755
index 25d5c9d..fe865ae
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -252,14 +252,6 @@
         self.usage = psa_storage.as_expr(expected_usage) if expected_usage is not None else\
                      self.original_usage  #type: psa_storage.Expr
 
-class StorageKeyBuilder:
-    # pylint: disable=too-few-public-methods
-    def __init__(self, usage_extension: bool) -> None:
-        self.usage_extension = usage_extension #type: bool
-
-    def build(self, **kwargs) -> StorageKey:
-        return StorageKey(implicit_usage=self.usage_extension, **kwargs)
-
 class StorageFormat:
     """Storage format stability test cases."""
 
@@ -276,7 +268,6 @@
         self.constructors = info.constructors #type: macro_collector.PSAMacroEnumerator
         self.version = version #type: int
         self.forward = forward #type: bool
-        self.key_builder = StorageKeyBuilder(usage_extension=True) #type: StorageKeyBuilder
 
     def make_test_case(self, key: StorageKey) -> test_case.TestCase:
         """Construct a storage format test case for the given key.
@@ -329,19 +320,17 @@
                        r'', short)
         short = re.sub(r'PSA_KEY_[A-Z]+_', r'', short)
         description = 'lifetime: ' + short
-        key = self.key_builder.build(version=self.version,
-                                     id=1, lifetime=lifetime,
-                                     type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                                     usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0,
-                                     material=b'L',
-                                     description=description)
-        return key
+        return StorageKey(version=self.version,
+                          id=1, lifetime=lifetime,
+                          type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                          usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0,
+                          material=b'L',
+                          description=description)
 
-    def all_keys_for_lifetimes(self) -> List[StorageKey]:
+    def all_keys_for_lifetimes(self) -> Iterator[StorageKey]:
         """Generate test keys covering lifetimes."""
         lifetimes = sorted(self.constructors.lifetimes)
         expressions = self.constructors.generate_expressions(lifetimes)
-        keys = [] #type List[StorageKey]
         for lifetime in expressions:
             # Don't attempt to create or load a volatile key in storage
             if 'VOLATILE' in lifetime:
@@ -350,67 +339,67 @@
             # but do attempt to load one.
             if 'READ_ONLY' in lifetime and self.forward:
                 continue
-            keys.append(self.key_for_lifetime(lifetime))
-        return keys
+            yield self.key_for_lifetime(lifetime)
 
     def key_for_usage_flags(
             self,
             usage_flags: List[str],
             short: Optional[str] = None,
-            extra_desc: Optional[str] = None
-    ) -> StorageKey:
+            test_implicit_usage: Optional[bool] = False
+    ) -> Iterator[StorageKey]:
         """Construct a test key for the given key usage."""
         usage = ' | '.join(usage_flags) if usage_flags else '0'
         if short is None:
             short = re.sub(r'\bPSA_KEY_USAGE_', r'', usage)
-        extra_desc = ' ' + extra_desc if extra_desc else ''
+        extra_desc = ' with implication' if test_implicit_usage else ''
         description = 'usage' + extra_desc + ': ' + short
-        return self.key_builder.build(version=self.version,
-                                      id=1, lifetime=0x00000001,
-                                      type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                                      usage=usage, alg=0, alg2=0,
-                                      material=b'K',
-                                      description=description)
+        yield StorageKey(version=self.version,
+                         id=1, lifetime=0x00000001,
+                         type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                         usage=usage, alg=0, alg2=0,
+                         material=b'K',
+                         description=description,
+                         implicit_usage=True)
+        if test_implicit_usage:
+            description = 'usage without implication' + ': ' + short
+            yield StorageKey(version=self.version,
+                             id=1, lifetime=0x00000001,
+                             type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                             usage=usage, alg=0, alg2=0,
+                             material=b'K',
+                             description=description,
+                             implicit_usage=False)
 
-    def generate_keys_for_usage_flags(
-            self,
-            extra_desc: Optional[str] = None
-    ) -> List[StorageKey]:
+
+    def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageKey]:
         """Generate test keys covering usage flags."""
         known_flags = sorted(self.constructors.key_usage_flags)
-        keys = [] #type List[StorageKey]
-        keys.append(self.key_for_usage_flags(['0'], extra_desc=extra_desc))
-        keys += [self.key_for_usage_flags([usage_flag], extra_desc=extra_desc)
-                 for usage_flag in known_flags]
-        keys += [self.key_for_usage_flags([flag1, flag2], extra_desc=extra_desc)
-                 for flag1, flag2 in zip(known_flags,
-                                         known_flags[1:] + [known_flags[0]])]
-        return keys
+        yield from self.key_for_usage_flags(['0'], **kwargs)
+        for usage_flag in known_flags:
+            yield from self.key_for_usage_flags([usage_flag], **kwargs)
+        for flag1, flag2 in zip(known_flags,
+                                known_flags[1:] + [known_flags[0]]):
+            yield from self.key_for_usage_flags([flag1, flag2], **kwargs)
 
-    def generate_key_for_all_usage_flags(self) -> StorageKey:
+    def generate_key_for_all_usage_flags(self) -> Iterator[StorageKey]:
         known_flags = sorted(self.constructors.key_usage_flags)
-        return self.key_for_usage_flags(known_flags, short='all known')
+        yield from self.key_for_usage_flags(known_flags, short='all known')
 
-    def all_keys_for_usage_flags(
-            self,
-            extra_desc: Optional[str] = None
-    ) -> List[StorageKey]:
-        keys = self.generate_keys_for_usage_flags(extra_desc=extra_desc)
-        keys.append(self.generate_key_for_all_usage_flags())
-        return keys
+    def all_keys_for_usage_flags(self) -> Iterator[StorageKey]:
+        yield from self.generate_keys_for_usage_flags()
+        yield from self.generate_key_for_all_usage_flags()
 
     def keys_for_type(
             self,
             key_type: str,
             params: Optional[Iterable[str]] = None
-    ) -> List[StorageKey]:
+    ) -> Iterator[StorageKey]:
         """Generate test keys for the given key type.
 
         For key types that depend on a parameter (e.g. elliptic curve family),
         `param` is the parameter to pass to the constructor. Only a single
         parameter is supported.
         """
-        keys = [] #type: List[StorageKey]
         kt = crypto_knowledge.KeyType(key_type, params)
         for bits in kt.sizes_to_test():
             usage_flags = 'PSA_KEY_USAGE_EXPORT'
@@ -421,22 +410,20 @@
                                       r'',
                                       kt.expression)
             description = 'type: {} {}-bit'.format(short_expression, bits)
-            keys.append(self.key_builder.build(version=self.version,
-                                               id=1, lifetime=0x00000001,
-                                               type=kt.expression, bits=bits,
-                                               usage=usage_flags, alg=alg, alg2=alg2,
-                                               material=key_material,
-                                               description=description))
-        return keys
+            yield StorageKey(version=self.version,
+                             id=1, lifetime=0x00000001,
+                             type=kt.expression, bits=bits,
+                             usage=usage_flags, alg=alg, alg2=alg2,
+                             material=key_material,
+                             description=description)
 
-    def all_keys_for_types(self) -> List[StorageKey]:
+    def all_keys_for_types(self) -> Iterator[StorageKey]:
         """Generate test keys covering key types and their representations."""
         key_types = sorted(self.constructors.key_types)
-        return [key
-                for key_type in self.constructors.generate_expressions(key_types)
-                for key in self.keys_for_type(key_type)]
+        for key_type in self.constructors.generate_expressions(key_types):
+            yield from self.keys_for_type(key_type)
 
-    def keys_for_algorithm(self, alg: str) -> List[StorageKey]:
+    def keys_for_algorithm(self, alg: str) -> Iterator[StorageKey]:
         """Generate test keys for the specified algorithm."""
         # For now, we don't have information on the compatibility of key
         # types and algorithms. So we just test the encoding of algorithms,
@@ -444,26 +431,24 @@
         descr = re.sub(r'PSA_ALG_', r'', alg)
         descr = re.sub(r',', r', ', re.sub(r' +', r'', descr))
         usage = 'PSA_KEY_USAGE_EXPORT'
-        key1 = self.key_builder.build(version=self.version,
-                                      id=1, lifetime=0x00000001,
-                                      type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                                      usage=usage, alg=alg, alg2=0,
-                                      material=b'K',
-                                      description='alg: ' + descr)
-        key2 = self.key_builder.build(version=self.version,
-                                      id=1, lifetime=0x00000001,
-                                      type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                                      usage=usage, alg=0, alg2=alg,
-                                      material=b'L',
-                                      description='alg2: ' + descr)
-        return [key1, key2]
+        yield StorageKey(version=self.version,
+                         id=1, lifetime=0x00000001,
+                         type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                         usage=usage, alg=alg, alg2=0,
+                         material=b'K',
+                         description='alg: ' + descr)
+        yield StorageKey(version=self.version,
+                         id=1, lifetime=0x00000001,
+                         type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                         usage=usage, alg=0, alg2=alg,
+                         material=b'L',
+                         description='alg2: ' + descr)
 
-    def all_keys_for_algorithms(self) -> List[StorageKey]:
+    def all_keys_for_algorithms(self) -> Iterator[StorageKey]:
         """Generate test keys covering algorithm encodings."""
         algorithms = sorted(self.constructors.algorithms)
-        return [key
-                for alg in self.constructors.generate_expressions(algorithms)
-                for key in self.keys_for_algorithm(alg)]
+        for alg in self.constructors.generate_expressions(algorithms):
+            yield from self.keys_for_algorithm(alg)
 
     def generate_all_keys(self) -> List[StorageKey]:
         """Generate all keys for the test cases."""
@@ -474,18 +459,19 @@
         keys += self.all_keys_for_algorithms()
         return keys
 
-    def all_test_cases(self) -> List[test_case.TestCase]:
+    def all_test_cases(self) -> Iterator[test_case.TestCase]:
         """Generate all storage format test cases."""
         # First build a list of all keys, then construct all the corresponding
         # test cases. This allows all required information to be obtained in
         # one go, which is a significant performance gain as the information
         # includes numerical values obtained by compiling a C program.
-        keys = self.generate_all_keys()
-
-        # Skip keys with a non-default location, because they
-        # require a driver and we currently have no mechanism to
-        # determine whether a driver is available.
-        return [self.make_test_case(key) for key in keys if key.location_value() == 0]
+        for key in self.generate_all_keys():
+            if key.location_value() != 0:
+                # Skip keys with a non-default location, because they
+                # require a driver and we currently have no mechanism to
+                # determine whether a driver is available.
+                continue
+            yield self.make_test_case(key)
 
 class StorageFormatForward(StorageFormat):
     """Storage format stability test cases for forward compatibility."""
@@ -499,29 +485,10 @@
     def __init__(self, info: Information) -> None:
         super().__init__(info, 0, False)
 
-    def all_keys_for_usage_flags(
-            self,
-            extra_desc: Optional[str] = None
-    ) -> List[StorageKey]:
+    def all_keys_for_usage_flags(self) -> Iterator[StorageKey]:
         """Generate test keys covering usage flags."""
-        # First generate keys without usage policy extension for
-        # compatibility testing, then generate the keys with extension
-        # to check the extension is working. Finally generate key for all known
-        # usage flag which needs to be separted because it is not affected by
-        # usage extension.
-        keys = [] #type: List[StorageKey]
-        prev_builder = self.key_builder
-
-        self.key_builder = StorageKeyBuilder(usage_extension=False)
-        keys += self.generate_keys_for_usage_flags(extra_desc='without extension')
-
-        self.key_builder = StorageKeyBuilder(usage_extension=True)
-        keys += self.generate_keys_for_usage_flags(extra_desc='with extension')
-
-        keys.append(self.generate_key_for_all_usage_flags())
-
-        self.key_builder = prev_builder
-        return keys
+        yield from self.generate_keys_for_usage_flags(test_implicit_usage=True)
+        yield from self.generate_key_for_all_usage_flags()
 
     def keys_for_implicit_usage(
             self,
@@ -529,12 +496,11 @@
             alg: str,
             key_type: str,
             params: Optional[Iterable[str]] = None
-    ) -> List[StorageKey]:
+    ) -> StorageKey:
         # pylint: disable=too-many-locals
         """Generate test keys for the specified implicit usage flag,
            algorithm and key type combination.
         """
-        keys = [] #type: List[StorageKey]
         kt = crypto_knowledge.KeyType(key_type, params)
         bits = kt.sizes_to_test()[0]
         implicit_usage = StorageKey.IMPLICIT_USAGE_FLAGS[implyer_usage]
@@ -551,15 +517,15 @@
                                      kt.expression)
         description = 'implied by {}: {} {} {}-bit'.format(
             usage_expression, alg_expression, key_type_expression, bits)
-        keys.append(self.key_builder.build(version=self.version,
-                                           id=1, lifetime=0x00000001,
-                                           type=kt.expression, bits=bits,
-                                           usage=material_usage_flags,
-                                           expected_usage=expected_usage_flags,
-                                           alg=alg, alg2=alg2,
-                                           material=key_material,
-                                           description=description))
-        return keys
+        return StorageKey(version=self.version,
+                          id=1, lifetime=0x00000001,
+                          type=kt.expression, bits=bits,
+                          usage=material_usage_flags,
+                          expected_usage=expected_usage_flags,
+                          alg=alg, alg2=alg2,
+                          material=key_material,
+                          description=description,
+                          implicit_usage=False)
 
     def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
         # pylint: disable=too-many-locals
@@ -609,29 +575,20 @@
                         alg_with_keys[alg] = [key_type]
         return alg_with_keys
 
-    def all_keys_for_implicit_usage(self) -> List[StorageKey]:
+    def all_keys_for_implicit_usage(self) -> Iterator[StorageKey]:
         """Generate test keys for usage flag extensions."""
         # Generate a key type and algorithm pair for each extendable usage
         # flag to generate a valid key for exercising. The key is generated
         # without usage extension to check the extension compatiblity.
-        keys = [] #type: List[StorageKey]
-        prev_builder = self.key_builder
-
-        # Generate the keys without usage extension
-        self.key_builder = StorageKeyBuilder(usage_extension=False)
         alg_with_keys = self.gather_key_types_for_sign_alg()
         key_filter = StorageKey.IMPLICIT_USAGE_FLAGS_KEY_RESTRICTION
 
-        # Walk through all combintion. The key types must be filtered to fit
-        # the specific usage flag.
-        keys += [key
-                 for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str)
-                 for alg in sorted(alg_with_keys)
-                 for key_type in sorted(alg_with_keys[alg]) if re.match(key_filter[usage], key_type)
-                 for key in self.keys_for_implicit_usage(usage, alg, key_type)]
-
-        self.key_builder = prev_builder
-        return keys
+        for usage in sorted(StorageKey.IMPLICIT_USAGE_FLAGS, key=str):
+            for alg in sorted(alg_with_keys):
+                for key_type in sorted(alg_with_keys[alg]):
+                    # The key types must be filtered to fit the specific usage flag.
+                    if re.match(key_filter[usage], key_type):
+                        yield self.keys_for_implicit_usage(usage, alg, key_type)
 
     def generate_all_keys(self) -> List[StorageKey]:
         keys = super().generate_all_keys()