Refactor handlibg of the key usage flags

Move implicit usage flags handling to the StorageKey class.
Create a subclass for test case data.

Signed-off-by: gabor-mezei-arm <gabor.mezei@arm.com>
diff --git a/scripts/mbedtls_dev/psa_storage.py b/scripts/mbedtls_dev/psa_storage.py
index 88992a6..45f0380 100644
--- a/scripts/mbedtls_dev/psa_storage.py
+++ b/scripts/mbedtls_dev/psa_storage.py
@@ -101,12 +101,6 @@
     LATEST_VERSION = 0
     """The latest version of the storage format."""
 
-    IMPLICIT_USAGE_FLAGS = {
-        'PSA_KEY_USAGE_SIGN_HASH': 'PSA_KEY_USAGE_SIGN_MESSAGE',
-        'PSA_KEY_USAGE_VERIFY_HASH': 'PSA_KEY_USAGE_VERIFY_MESSAGE'
-    } #type: Dict[str, str]
-    """Mapping of usage flags to the flags that they imply."""
-
     def __init__(self, *,
                  version: Optional[int] = None,
                  id: Optional[int] = None, #pylint: disable=redefined-builtin
@@ -114,27 +108,18 @@
                  type: Exprable, #pylint: disable=redefined-builtin
                  bits: int,
                  usage: Exprable, alg: Exprable, alg2: Exprable,
-                 material: bytes, #pylint: disable=used-before-assignment
-                 implicit_usage: bool = True
+                 material: bytes #pylint: disable=used-before-assignment
                 ) -> None:
         self.version = self.LATEST_VERSION if version is None else version
         self.id = id #pylint: disable=invalid-name #type: Optional[int]
         self.lifetime = as_expr(lifetime) #type: Expr
         self.type = as_expr(type) #type: Expr
         self.bits = bits #type: int
-        self.original_usage = as_expr(usage) #type: Expr
-        self.updated_usage = self.original_usage #type: Expr
+        self.usage = as_expr(usage) #type: Expr
         self.alg = as_expr(alg) #type: Expr
         self.alg2 = as_expr(alg2) #type: Expr
         self.material = material #type: bytes
 
-        if implicit_usage:
-            for flag, extension in self.IMPLICIT_USAGE_FLAGS.items():
-                if self.original_usage.value() & Expr(flag).value() and \
-                   self.original_usage.value() & Expr(extension).value() == 0:
-                    self.updated_usage = Expr(self.updated_usage.string +
-                                              ' | ' + extension)
-
     MAGIC = b'PSA\000KEY\000'
 
     @staticmethod
@@ -166,7 +151,7 @@
         if self.version == 0:
             attributes = self.pack('LHHLLL',
                                    self.lifetime, self.type, self.bits,
-                                   self.updated_usage, self.alg, self.alg2)
+                                   self.usage, self.alg, self.alg2)
             material = self.pack('L', len(self.material)) + self.material
         else:
             raise NotImplementedError
diff --git a/tests/scripts/generate_psa_tests.py b/tests/scripts/generate_psa_tests.py
index 4e2842c..3eb7c4b 100755
--- a/tests/scripts/generate_psa_tests.py
+++ b/tests/scripts/generate_psa_tests.py
@@ -233,24 +233,50 @@
 class StorageKey(psa_storage.Key):
     """Representation of a key for storage format testing."""
 
+    IMPLICIT_USAGE_FLAGS = {
+        'PSA_KEY_USAGE_SIGN_HASH': 'PSA_KEY_USAGE_SIGN_MESSAGE',
+        'PSA_KEY_USAGE_VERIFY_HASH': 'PSA_KEY_USAGE_VERIFY_MESSAGE'
+    } #type: Dict[str, str]
+    """Mapping of usage flags to the flags that they imply."""
+
+    def __init__(
+            self,
+            usage: str,
+            without_implicit_usage: Optional[bool] = False,
+            **kwargs
+    ) -> None:
+        """Prepare to generate a key.
+
+        * `usage`                 : The usage flags used for the key.
+        * `without_implicit_usage`: Flag to defide to apply the usage extension
+        """
+        super().__init__(usage=usage,**kwargs)
+
+        if not without_implicit_usage:
+            for flag, implicit in self.IMPLICIT_USAGE_FLAGS.items():
+                if self.usage.value() & psa_storage.Expr(flag).value() and \
+                   self.usage.value() & psa_storage.Expr(implicit).value() == 0:
+                    self.usage = psa_storage.Expr(self.usage.string + ' | ' + implicit)
+
+class StorageTestData(StorageKey):
+    """Representation of test case data for storage format testing."""
+
     def __init__(
             self,
             description: str,
             expected_usage: Optional[str] = None,
             **kwargs
     ) -> None:
-        """Prepare to generate a key.
+        """Prepare to generate test data
 
-        * `description`: used for the the test case names
-        * `implicit_usage`: the usage flags generated as the expected usage
-                            flags in the test cases. When testing implicit
-                            usage flags, they can differ in the generated keys
-                            and the expected usage flags in the test cases.
+        * `description`   : used for the the test case names
+        * `expected_usage`: the usage flags generated as the expected usage flags
+                            in the test cases. CAn differ from the usage flags
+                            stored in the keys because of the usage flags extension.
         """
         super().__init__(**kwargs)
         self.description = description #type: str
-        self.usage = psa_storage.as_expr(expected_usage) if expected_usage is not None else\
-                     self.original_usage  #type: psa_storage.Expr
+        self.expected_usage = expected_usage if expected_usage else self.usage.string #type: str
 
 class StorageFormat:
     """Storage format stability test cases."""
@@ -269,7 +295,7 @@
         self.version = version #type: int
         self.forward = forward #type: bool
 
-    def make_test_case(self, key: StorageKey) -> test_case.TestCase:
+    def make_test_case(self, key: StorageTestData) -> test_case.TestCase:
         """Construct a storage format test case for the given key.
 
         If ``forward`` is true, generate a forward compatibility test case:
@@ -283,7 +309,7 @@
         tc.set_description('PSA storage {}: {}'.format(verb, key.description))
         dependencies = automatic_dependencies(
             key.lifetime.string, key.type.string,
-            key.usage.string, key.alg.string, key.alg2.string,
+            key.expected_usage, key.alg.string, key.alg2.string,
         )
         dependencies = finish_family_dependencies(dependencies, key.bits)
         tc.set_dependencies(dependencies)
@@ -304,7 +330,7 @@
             extra_arguments = [' | '.join(flags) if flags else '0']
         tc.set_arguments([key.lifetime.string,
                           key.type.string, str(key.bits),
-                          key.usage.string, key.alg.string, key.alg2.string,
+                          key.expected_usage, key.alg.string, key.alg2.string,
                           '"' + key.material.hex() + '"',
                           '"' + key.hex() + '"',
                           *extra_arguments])
@@ -313,21 +339,22 @@
     def key_for_lifetime(
             self,
             lifetime: str,
-    ) -> StorageKey:
+    ) -> StorageTestData:
         """Construct a test key for the given lifetime."""
         short = lifetime
         short = re.sub(r'PSA_KEY_LIFETIME_FROM_PERSISTENCE_AND_LOCATION',
                        r'', short)
         short = re.sub(r'PSA_KEY_[A-Z]+_', r'', short)
         description = 'lifetime: ' + short
-        return StorageKey(version=self.version,
-                          id=1, lifetime=lifetime,
-                          type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                          usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0,
-                          material=b'L',
-                          description=description)
+        key = StorageTestData(version=self.version,
+                              id=1, lifetime=lifetime,
+                              type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                              usage='PSA_KEY_USAGE_EXPORT', alg=0, alg2=0,
+                              material=b'L',
+                              description=description)
+        return key
 
-    def all_keys_for_lifetimes(self) -> Iterator[StorageKey]:
+    def all_keys_for_lifetimes(self) -> Iterator[StorageTestData]:
         """Generate test keys covering lifetimes."""
         lifetimes = sorted(self.constructors.lifetimes)
         expressions = self.constructors.generate_expressions(lifetimes)
@@ -346,32 +373,34 @@
             usage_flags: List[str],
             short: Optional[str] = None,
             test_implicit_usage: Optional[bool] = False
-    ) -> Iterator[StorageKey]:
+    ) -> Iterator[StorageTestData]:
         """Construct a test key for the given key usage."""
         usage = ' | '.join(usage_flags) if usage_flags else '0'
         if short is None:
             short = re.sub(r'\bPSA_KEY_USAGE_', r'', usage)
         extra_desc = ' with implication' if test_implicit_usage else ''
         description = 'usage' + extra_desc + ': ' + short
-        yield StorageKey(version=self.version,
-                         id=1, lifetime=0x00000001,
-                         type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                         usage=usage, alg=0, alg2=0,
-                         material=b'K',
-                         description=description,
-                         implicit_usage=True)
+        key1 = StorageTestData(version=self.version,
+                               id=1, lifetime=0x00000001,
+                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                               expected_usage=usage,
+                               usage=usage, alg=0, alg2=0,
+                               material=b'K',
+                               description=description)
+        yield key1
+
         if test_implicit_usage:
             description = 'usage without implication' + ': ' + short
-            yield StorageKey(version=self.version,
-                             id=1, lifetime=0x00000001,
-                             type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                             usage=usage, alg=0, alg2=0,
-                             material=b'K',
-                             description=description,
-                             implicit_usage=False)
+            key2 = StorageTestData(version=self.version,
+                                   id=1, lifetime=0x00000001,
+                                   type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                                   without_implicit_usage=True,
+                                   usage=usage, alg=0, alg2=0,
+                                   material=b'K',
+                                   description=description)
+            yield key2
 
-
-    def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageKey]:
+    def generate_keys_for_usage_flags(self, **kwargs) -> Iterator[StorageTestData]:
         """Generate test keys covering usage flags."""
         known_flags = sorted(self.constructors.key_usage_flags)
         yield from self.key_for_usage_flags(['0'], **kwargs)
@@ -381,11 +410,11 @@
                                 known_flags[1:] + [known_flags[0]]):
             yield from self.key_for_usage_flags([flag1, flag2], **kwargs)
 
-    def generate_key_for_all_usage_flags(self) -> Iterator[StorageKey]:
+    def generate_key_for_all_usage_flags(self) -> Iterator[StorageTestData]:
         known_flags = sorted(self.constructors.key_usage_flags)
         yield from self.key_for_usage_flags(known_flags, short='all known')
 
-    def all_keys_for_usage_flags(self) -> Iterator[StorageKey]:
+    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
         yield from self.generate_keys_for_usage_flags()
         yield from self.generate_key_for_all_usage_flags()
 
@@ -393,7 +422,7 @@
             self,
             key_type: str,
             params: Optional[Iterable[str]] = None
-    ) -> Iterator[StorageKey]:
+    ) -> Iterator[StorageTestData]:
         """Generate test keys for the given key type.
 
         For key types that depend on a parameter (e.g. elliptic curve family),
@@ -410,20 +439,21 @@
                                       r'',
                                       kt.expression)
             description = 'type: {} {}-bit'.format(short_expression, bits)
-            yield StorageKey(version=self.version,
-                             id=1, lifetime=0x00000001,
-                             type=kt.expression, bits=bits,
-                             usage=usage_flags, alg=alg, alg2=alg2,
-                             material=key_material,
-                             description=description)
+            key = StorageTestData(version=self.version,
+                                  id=1, lifetime=0x00000001,
+                                  type=kt.expression, bits=bits,
+                                  usage=usage_flags, alg=alg, alg2=alg2,
+                                  material=key_material,
+                                  description=description)
+            yield key
 
-    def all_keys_for_types(self) -> Iterator[StorageKey]:
+    def all_keys_for_types(self) -> Iterator[StorageTestData]:
         """Generate test keys covering key types and their representations."""
         key_types = sorted(self.constructors.key_types)
         for key_type in self.constructors.generate_expressions(key_types):
             yield from self.keys_for_type(key_type)
 
-    def keys_for_algorithm(self, alg: str) -> Iterator[StorageKey]:
+    def keys_for_algorithm(self, alg: str) -> Iterator[StorageTestData]:
         """Generate test keys for the specified algorithm."""
         # For now, we don't have information on the compatibility of key
         # types and algorithms. So we just test the encoding of algorithms,
@@ -431,28 +461,30 @@
         descr = re.sub(r'PSA_ALG_', r'', alg)
         descr = re.sub(r',', r', ', re.sub(r' +', r'', descr))
         usage = 'PSA_KEY_USAGE_EXPORT'
-        yield StorageKey(version=self.version,
-                         id=1, lifetime=0x00000001,
-                         type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                         usage=usage, alg=alg, alg2=0,
-                         material=b'K',
-                         description='alg: ' + descr)
-        yield StorageKey(version=self.version,
-                         id=1, lifetime=0x00000001,
-                         type='PSA_KEY_TYPE_RAW_DATA', bits=8,
-                         usage=usage, alg=0, alg2=alg,
-                         material=b'L',
-                         description='alg2: ' + descr)
+        key1 = StorageTestData(version=self.version,
+                               id=1, lifetime=0x00000001,
+                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                               usage=usage, alg=alg, alg2=0,
+                               material=b'K',
+                               description='alg: ' + descr)
+        yield key1
+        key2 = StorageTestData(version=self.version,
+                               id=1, lifetime=0x00000001,
+                               type='PSA_KEY_TYPE_RAW_DATA', bits=8,
+                               usage=usage, alg=0, alg2=alg,
+                               material=b'L',
+                               description='alg2: ' + descr)
+        yield key2
 
-    def all_keys_for_algorithms(self) -> Iterator[StorageKey]:
+    def all_keys_for_algorithms(self) -> Iterator[StorageTestData]:
         """Generate test keys covering algorithm encodings."""
         algorithms = sorted(self.constructors.algorithms)
         for alg in self.constructors.generate_expressions(algorithms):
             yield from self.keys_for_algorithm(alg)
 
-    def generate_all_keys(self) -> List[StorageKey]:
+    def generate_all_keys(self) -> List[StorageTestData]:
         """Generate all keys for the test cases."""
-        keys = [] #type: List[StorageKey]
+        keys = [] #type: List[StorageTestData]
         keys += self.all_keys_for_lifetimes()
         keys += self.all_keys_for_usage_flags()
         keys += self.all_keys_for_types()
@@ -485,7 +517,7 @@
     def __init__(self, info: Information) -> None:
         super().__init__(info, 0, False)
 
-    def all_keys_for_usage_flags(self) -> Iterator[StorageKey]:
+    def all_keys_for_usage_flags(self) -> Iterator[StorageTestData]:
         """Generate test keys covering usage flags."""
         yield from self.generate_keys_for_usage_flags(test_implicit_usage=True)
         yield from self.generate_key_for_all_usage_flags()
@@ -495,7 +527,7 @@
             implyer_usage: str,
             alg: str,
             key_type: crypto_knowledge.KeyType
-    ) -> StorageKey:
+    ) -> StorageTestData:
         # pylint: disable=too-many-locals
         """Generate test keys for the specified implicit usage flag,
            algorithm and key type combination.
@@ -515,15 +547,16 @@
                                      key_type.expression)
         description = 'implied by {}: {} {} {}-bit'.format(
             usage_expression, alg_expression, key_type_expression, bits)
-        return StorageKey(version=self.version,
-                          id=1, lifetime=0x00000001,
-                          type=key_type.expression, bits=bits,
-                          usage=material_usage_flags,
-                          expected_usage=expected_usage_flags,
-                          alg=alg, alg2=alg2,
-                          material=key_material,
-                          description=description,
-                          implicit_usage=False)
+        key = StorageTestData(version=self.version,
+                              id=1, lifetime=0x00000001,
+                              type=key_type.expression, bits=bits,
+                              usage=material_usage_flags,
+                              expected_usage=expected_usage_flags,
+                              without_implicit_usage=True,
+                              alg=alg, alg2=alg2,
+                              material=key_material,
+                              description=description)
+        return key
 
     def gather_key_types_for_sign_alg(self) -> Dict[str, List[str]]:
         # pylint: disable=too-many-locals
@@ -573,7 +606,7 @@
                         alg_with_keys[alg] = [key_type]
         return alg_with_keys
 
-    def all_keys_for_implicit_usage(self) -> Iterator[StorageKey]:
+    def all_keys_for_implicit_usage(self) -> Iterator[StorageTestData]:
         """Generate test keys for usage flag extensions."""
         # Generate a key type and algorithm pair for each extendable usage
         # flag to generate a valid key for exercising. The key is generated
@@ -588,7 +621,7 @@
                     if kt.is_valid_for_signature(usage):
                         yield self.keys_for_implicit_usage(usage, alg, kt)
 
-    def generate_all_keys(self) -> List[StorageKey]:
+    def generate_all_keys(self) -> List[StorageTestData]:
         keys = super().generate_all_keys()
         keys += self.all_keys_for_implicit_usage()
         return keys