psa_constant_names: support key agreement algorithms
diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py
index 382fd23..dac6003 100755
--- a/scripts/generate_psa_constants.py
+++ b/scripts/generate_psa_constants.py
@@ -30,6 +30,14 @@
     }
 }
 
+static const char *psa_ka_algorithm_name(psa_algorithm_t ka_alg)
+{
+    switch (ka_alg) {
+    %(ka_algorithm_cases)s
+    default: return NULL;
+    }
+}
+
 static int psa_snprint_key_type(char *buffer, size_t buffer_size,
                                 psa_key_type_t type)
 {
@@ -47,12 +55,13 @@
     return (int) required_size;
 }
 
+#define NO_LENGTH_MODIFIER 0xfffffffflu
 static int psa_snprint_algorithm(char *buffer, size_t buffer_size,
                                  psa_algorithm_t alg)
 {
     size_t required_size = 0;
     psa_algorithm_t core_alg = alg;
-    unsigned long length_modifier = 0;
+    unsigned long length_modifier = NO_LENGTH_MODIFIER;
     if (PSA_ALG_IS_MAC(alg)) {
         core_alg = PSA_ALG_TRUNCATED_MAC(alg, 0);
         if (core_alg != alg) {
@@ -70,6 +79,15 @@
                    "PSA_ALG_AEAD_WITH_TAG_LENGTH(", 29);
             length_modifier = PSA_AEAD_TAG_LENGTH(alg);
         }
+    } else if (PSA_ALG_IS_KEY_AGREEMENT(alg) &&
+               !PSA_ALG_IS_RAW_KEY_AGREEMENT(alg)) {
+        core_alg = PSA_ALG_KEY_AGREEMENT_GET_KDF(alg);
+        append(&buffer, buffer_size, &required_size,
+               "PSA_ALG_KEY_AGREEMENT(", 22);
+        append_with_alg(&buffer, buffer_size, &required_size,
+                        psa_ka_algorithm_name,
+                        PSA_ALG_KEY_AGREEMENT_GET_BASE(alg));
+        append(&buffer, buffer_size, &required_size, ", ", 2);
     }
     switch (core_alg) {
     %(algorithm_cases)s
@@ -81,9 +99,11 @@
         break;
     }
     if (core_alg != alg) {
-        append(&buffer, buffer_size, &required_size, ", ", 2);
-        append_integer(&buffer, buffer_size, &required_size,
-                       "%%lu", length_modifier);
+        if (length_modifier != NO_LENGTH_MODIFIER) {
+            append(&buffer, buffer_size, &required_size, ", ", 2);
+            append_integer(&buffer, buffer_size, &required_size,
+                           "%%lu", length_modifier);
+        }
         append(&buffer, buffer_size, &required_size, ")", 1);
     }
     buffer[0] = 0;
@@ -126,9 +146,12 @@
         } else '''
 
 algorithm_from_hash_template = '''if (%(tester)s(core_alg)) {
-            append_with_hash(&buffer, buffer_size, &required_size,
-                             "%(builder)s", %(builder_length)s,
-                             PSA_ALG_GET_HASH(core_alg));
+            append(&buffer, buffer_size, &required_size,
+                   "%(builder)s(", %(builder_length)s + 1);
+            append_with_alg(&buffer, buffer_size, &required_size,
+                            psa_hash_algorithm_name,
+                            PSA_ALG_GET_HASH(core_alg));
+            append(&buffer, buffer_size, &required_size, ")", 1);
         } else '''
 
 bit_test_template = '''\
@@ -149,6 +172,7 @@
         self.ecc_curves = set()
         self.algorithms = set()
         self.hash_algorithms = set()
+        self.ka_algorithms = set()
         self.algorithms_from_hash = {}
         self.key_usages = set()
 
@@ -193,6 +217,9 @@
             # Ad hoc detection of hash algorithms
             if re.search(r'0x010000[0-9A-Fa-f]{2}', definition):
                 self.hash_algorithms.add(name)
+            # Ad hoc detection of key agreement algorithms
+            if re.search(r'0x30[0-9A-Fa-f]{2}0000', definition):
+                self.ka_algorithms.add(name)
         elif name.startswith('PSA_ALG_') and parameter == 'hash_alg':
             if name in ['PSA_ALG_DSA', 'PSA_ALG_ECDSA']:
                 # A naming irregularity
@@ -256,6 +283,10 @@
         return '\n    '.join(map(self.make_return_case,
                                  sorted(self.hash_algorithms)))
 
+    def make_ka_algorithm_cases(self):
+        return '\n    '.join(map(self.make_return_case,
+                                 sorted(self.ka_algorithms)))
+
     def make_algorithm_cases(self):
         return '\n    '.join(map(self.make_append_case,
                                  sorted(self.algorithms)))
@@ -281,6 +312,7 @@
         data['key_type_cases'] = self.make_key_type_cases()
         data['key_type_code'] = self.make_key_type_code()
         data['hash_algorithm_cases'] = self.make_hash_algorithm_cases()
+        data['ka_algorithm_cases'] = self.make_ka_algorithm_cases()
         data['algorithm_cases'] = self.make_algorithm_cases()
         data['algorithm_code'] = self.make_algorithm_code()
         data['key_usage_code'] = self.make_key_usage_code()