Add test case generation for usage extensions when loading keys

Add test cases validating that if a stored key only had the hash policy,
then after loading it psa_get_key_attributes reports that it also has the
message policy, and the key can be used with message functions.

Signed-off-by: gabor-mezei-arm <gabor.mezei@arm.com>
diff --git a/scripts/mbedtls_dev/macro_collector.py b/scripts/mbedtls_dev/macro_collector.py
index 6eb0d00..f9ef5f9 100644
--- a/scripts/mbedtls_dev/macro_collector.py
+++ b/scripts/mbedtls_dev/macro_collector.py
@@ -101,6 +101,7 @@
         self.kdf_algorithms = set() #type: Set[str]
         self.pake_algorithms = set() #type: Set[str]
         self.aead_algorithms = set() #type: Set[str]
+        self.sign_algorithms = set() #type: Set[str]
         # macro name -> list of argument names
         self.argspecs = {} #type: Dict[str, List[str]]
         # argument name -> list of values
@@ -135,6 +136,7 @@
         self.arguments_for['ka_alg'] = sorted(self.ka_algorithms)
         self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
         self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
+        self.arguments_for['sign_alg'] = sorted(self.sign_algorithms)
         self.arguments_for['curve'] = sorted(self.ecc_curves)
         self.arguments_for['group'] = sorted(self.dh_groups)
         self.arguments_for['persistence'] = sorted(self.persistence_levels)
@@ -368,11 +370,11 @@
             'hash_algorithm': [self.hash_algorithms],
             'mac_algorithm': [self.mac_algorithms],
             'cipher_algorithm': [],
-            'hmac_algorithm': [self.mac_algorithms],
+            'hmac_algorithm': [self.mac_algorithms, self.sign_algorithms],
             'aead_algorithm': [self.aead_algorithms],
             'key_derivation_algorithm': [self.kdf_algorithms],
             'key_agreement_algorithm': [self.ka_algorithms],
-            'asymmetric_signature_algorithm': [],
+            'asymmetric_signature_algorithm': [self.sign_algorithms],
             'asymmetric_signature_wildcard': [self.algorithms],
             'asymmetric_encryption_algorithm': [],
             'pake_algorithm': [self.pake_algorithms],
diff --git a/scripts/mbedtls_dev/psa_storage.py b/scripts/mbedtls_dev/psa_storage.py
index 4cd3dfe..ff2fdd4 100644
--- a/scripts/mbedtls_dev/psa_storage.py
+++ b/scripts/mbedtls_dev/psa_storage.py
@@ -107,6 +107,14 @@
     } #type: Dict[Expr, Expr]
     """The extendable usage flags with the corresponding extension flags."""
 
+    EXTENDABLE_USAGE_FLAGS_KEY_RESTRICTION = {
+        'PSA_KEY_USAGE_SIGN_HASH': '.*KEY_PAIR',
+        'PSA_KEY_USAGE_VERIFY_HASH': '.*KEY.*'
+    } #type: Dict[str, str]
+    """The key type filter for the extendable usage flags.
+    The filter is a regexp.
+    """
+
     def __init__(self, *,
                  version: Optional[int] = None,
                  id: Optional[int] = None, #pylint: disable=redefined-builtin
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index f1f3c42..5cf893b 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -233,10 +233,25 @@
 class StorageKey(psa_storage.Key):
     """Representation of a key for storage format testing."""
 
-    def __init__(self, *, description: str, **kwargs) -> None:
+    def __init__(
+            self,
+            description: str,
+            expected_usage: Optional[str] = None,
+            **kwargs
+    ) -> None:
+        """Prepare to generate a key.
+
+        * `description`: used for the the test case names
+        * `expected_usage`: the usage flags generated as the expected usage
+                            flags in the test cases. When testing usage
+                            extension the usage flags can differ in the
+                            generated key and the expected usage flags
+                            in the test cases.
+        """
         super().__init__(**kwargs)
         self.description = description #type: str
-        self.usage = self.original_usage #type: psa_storage.Expr
+        self.usage = psa_storage.as_expr(expected_usage) if expected_usage is not None else\
+                     self.original_usage  #type: psa_storage.Expr
 
 class StorageKeyBuilder:
     def __init__(self, usage_extension: bool) -> None:
@@ -475,7 +490,10 @@
     def __init__(self, info: Information) -> None:
         super().__init__(info, 0, False)
 
-    def all_keys_for_usage_flags(self) -> List[StorageKey]:
+    def all_keys_for_usage_flags(
+            self,
+            extra_desc: Optional[str] = None
+    ) -> List[StorageKey]:
         """Generate test keys covering usage flags."""
         # First generate keys without usage policy extension for
         # compatibility testing, then generate the keys with extension
@@ -492,6 +510,121 @@
         self.key_builder = prev_builder
         return keys
 
+    def keys_for_usage_extension(
+            self,
+            extendable: psa_storage.Expr,
+            alg: str,
+            key_type: str,
+            params: Optional[Iterable[str]] = None
+    ) -> List[StorageKey]:
+        """Generate test keys for the specified extendable usage flag,
+           algorithm and key type combination.
+        """
+        keys = [] #type: List[StorageKey]
+        kt = crypto_knowledge.KeyType(key_type, params)
+        for bits in kt.sizes_to_test():
+            extension = StorageKey.EXTENDABLE_USAGE_FLAGS[extendable]
+            usage_flags = 'PSA_KEY_USAGE_EXPORT'
+            material_usage_flags = usage_flags + ' | ' + extendable.string
+            expected_usage_flags = material_usage_flags + ' | ' + extension.string
+            alg2 = 0
+            key_material = kt.key_material(bits)
+            usage_expression = re.sub(r'PSA_KEY_USAGE_', r'', extendable.string)
+            alg_expression = re.sub(r'PSA_ALG_', r'', alg)
+            alg_expression = re.sub(r',', r', ', re.sub(r' +', r'', alg_expression))
+            key_type_expression = re.sub(r'\bPSA_(?:KEY_TYPE|ECC_FAMILY)_',
+                                      r'',
+                                      kt.expression)
+            description = 'extend {}: {} {} {}-bit'.format(
+                usage_expression, alg_expression, key_type_expression, bits)
+            keys.append(self.key_builder.build(
+                            version=self.version,
+                            id=1, lifetime=0x00000001,
+                            type=kt.expression, bits=bits,
+                            usage=material_usage_flags,
+                            expected_usage=expected_usage_flags,
+                            alg=alg, alg2=alg2,
+                            material=key_material,
+                            description=description))
+        return keys
+
+    def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
+        """Match possible key types for sign algorithms."""
+        # To create a valid combinaton both the algorithms and key types
+        # must be filtered. Pair them with keywords created from its names.
+        incompatible_alg_keyword = frozenset(['RAW', 'ANY', 'PURE'])
+        incompatible_key_type_keywords = frozenset(['MONTGOMERY'])
+        keyword_translation = {
+            'ECDSA': 'ECC',
+            'ED[0-9]*.*' : 'EDWARDS'
+        }
+        exclusive_keywords = {
+            'EDWARDS': 'ECC'
+        }
+        key_types = set(self.constructors.generate_expressions(
+                            self.constructors.key_types))
+        algorithms = set(self.constructors.generate_expressions(
+                            self.constructors.sign_algorithms))
+        alg_with_keys = {} #type: Dict[str, List[str]]
+        translation_table = str.maketrans('(', '_', ')')
+        for alg in algorithms:
+            # Generate keywords from the name of the algorithm
+            alg_keywords = set(alg.partition('(')[0].split(sep='_')[2:])
+            # Translate keywords for better matching with the key types
+            for keyword in alg_keywords.copy():
+                for pattern, replace in keyword_translation.items():
+                    if re.match(pattern, keyword):
+                        alg_keywords.remove(keyword)
+                        alg_keywords.add(replace)
+            # Filter out incompatible algortihms
+            if not alg_keywords.isdisjoint(incompatible_alg_keyword):
+                continue
+
+            for key_type in key_types:
+                # Generate keywords from the of the key type
+                key_type_keywords = set(key_type.translate(translation_table).split(sep='_')[3:])
+
+                # Remove ambigious keywords
+                for keyword1, keyword2 in exclusive_keywords.items():
+                    if keyword1 in key_type_keywords:
+                        key_type_keywords.remove(keyword2)
+
+                if key_type_keywords.isdisjoint(incompatible_key_type_keywords) and\
+                   not key_type_keywords.isdisjoint(alg_keywords):
+                    if alg in alg_with_keys:
+                        alg_with_keys[alg].append(key_type)
+                    else:
+                        alg_with_keys[alg] = [key_type]
+        return alg_with_keys
+
+    def all_keys_for_usage_extension(self) -> List[StorageKey]:
+        """Generate test keys for usage flag extensions."""
+        # Generate a key type and algorithm pair for each extendable usage
+        # flag to generate a valid key for exercising. The key is generated
+        # without usage extension to check the extension compatiblity.
+        keys = [] #type: List[StorageKey]
+        prev_builder = self.key_builder
+
+        # Generate the key without usage extension
+        self.key_builder = StorageKeyBuilder(usage_extension = False)
+        alg_with_keys = self.gather_key_types_for_sign_alg()
+        key_restrictions = StorageKey.EXTENDABLE_USAGE_FLAGS_KEY_RESTRICTION
+        # Walk through all combintion. The key types must be filtered to fit
+        # the specific usage flag.
+        keys += [key for usage in StorageKey.EXTENDABLE_USAGE_FLAGS.keys()
+                     for alg in sorted(alg_with_keys.keys())
+                     for key_type in sorted(filter(
+                            lambda kt: re.match(key_restrictions[usage.string], kt),
+                            alg_with_keys[alg]))
+                     for key in self.keys_for_usage_extension(usage, alg, key_type)]
+
+        self.key_builder = prev_builder
+        return keys
+
+    def generate_all_keys(self) -> List[StorageKey]:
+        keys = super().generate_all_keys()
+        keys += self.all_keys_for_usage_extension()
+        return keys
 
 class TestGenerator:
     """Generate test data."""