Move key buffer size computation out of psa_generate_key_internal()

Preparatory commit to eventually change
psa_generate_key_internal() signature to that of
a PSA driver generate_key entry point.

To be able to change the signature, the buffer to
store the key has to be allocated before the call
to psa_generate_key_internal() thus its size has
to be calculed beforehand as well.

This is the purpose of this commit: to move the
computation of the key size in bytes out of
psa_generate_key_internal().

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index b673d59..ab50d53 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -5984,24 +5984,83 @@
 }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) */
 
-static psa_status_t psa_generate_key_internal(
-    psa_key_slot_t *slot, size_t bits,
-    const uint8_t *domain_parameters, size_t domain_parameters_size )
+/** Get the key buffer size for the key material in export format
+ *
+ * \param[in] type  The key type
+ * \param[in] bits  The number of bits of the key
+ * \param[out] key_buffer_size  Minimum buffer size to contain the key material
+ *                              in export format
+ *
+ * \retval #PSA_SUCCESS
+ *         The minimum size for a buffer to contain the key material in export
+ *         format has been returned successfully.
+ * \retval #PSA_ERROR_INVALID_ARGUMENT
+ *         The size in bits of the key is not valid.
+ * \retval #PSA_ERROR_NOT_SUPPORTED
+ *         The type and/or the size in bits of the key or the combination of
+ *         the two is not supported.
+ */
+static psa_status_t psa_get_key_buffer_size(
+    psa_key_type_t type, size_t bits, size_t *key_buffer_size )
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
-    psa_key_type_t type = slot->attr.type;
-
-    if( domain_parameters == NULL && domain_parameters_size != 0 )
-        return( PSA_ERROR_INVALID_ARGUMENT );
 
     if( key_type_is_raw_bytes( type ) )
     {
         status = validate_unstructured_key_bit_size( type, bits );
         if( status != PSA_SUCCESS )
             return( status );
+        *key_buffer_size = PSA_BITS_TO_BYTES( bits );
+    }
+    else
+#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR)
+    if( PSA_KEY_TYPE_IS_RSA( type ) && PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
+    {
+        if( bits > PSA_VENDOR_RSA_MAX_KEY_BITS )
+            return( PSA_ERROR_NOT_SUPPORTED );
 
-        /* Allocate memory for the key */
-        status = psa_allocate_buffer_to_slot( slot, PSA_BITS_TO_BYTES( bits ) );
+        /* Accept only byte-aligned keys, for the same reasons as
+         * in psa_import_rsa_key(). */
+        if( bits % 8 != 0 )
+            return( PSA_ERROR_NOT_SUPPORTED );
+
+        *key_buffer_size = PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE( bits );
+    }
+    else
+#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) */
+
+#if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_KEY_PAIR)
+    if( PSA_KEY_TYPE_IS_ECC( type ) && PSA_KEY_TYPE_IS_KEY_PAIR( type ) )
+    {
+        *key_buffer_size = PSA_BITS_TO_BYTES( bits );
+    }
+    else
+#endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_ECC_KEY_PAIR) */
+    {
+        return( PSA_ERROR_NOT_SUPPORTED );
+    }
+
+    return( PSA_SUCCESS );
+}
+
+static psa_status_t psa_generate_key_internal(
+    psa_key_slot_t *slot, size_t bits,
+    const uint8_t *domain_parameters, size_t domain_parameters_size )
+{
+    psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
+    psa_key_type_t type = slot->attr.type;
+    size_t key_buffer_size;
+
+    if( domain_parameters == NULL && domain_parameters_size != 0 )
+        return( PSA_ERROR_INVALID_ARGUMENT );
+
+    status = psa_get_key_buffer_size( slot->attr.type, bits, &key_buffer_size );
+    if( status != PSA_SUCCESS )
+        return( status );
+
+    if( key_type_is_raw_bytes( type ) )
+    {
+        status = psa_allocate_buffer_to_slot( slot, key_buffer_size );
         if( status != PSA_SUCCESS )
             return( status );
 
@@ -6024,12 +6083,7 @@
         mbedtls_rsa_context rsa;
         int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
         int exponent;
-        if( bits > PSA_VENDOR_RSA_MAX_KEY_BITS )
-            return( PSA_ERROR_NOT_SUPPORTED );
-        /* Accept only byte-aligned keys, for the same reasons as
-         * in mbedtls_psa_rsa_import_key(). */
-        if( bits % 8 != 0 )
-            return( PSA_ERROR_NOT_SUPPORTED );
+
         status = psa_read_rsa_exponent( domain_parameters,
                                         domain_parameters_size,
                                         &exponent );
@@ -6044,10 +6098,7 @@
         if( ret != 0 )
             return( mbedtls_to_psa_error( ret ) );
 
-        /* Make sure to always have an export representation available */
-        size_t bytes = PSA_KEY_EXPORT_RSA_KEY_PAIR_MAX_SIZE( bits );
-
-        status = psa_allocate_buffer_to_slot( slot, bytes );
+        status = psa_allocate_buffer_to_slot( slot, key_buffer_size );
         if( status != PSA_SUCCESS )
         {
             mbedtls_rsa_free( &rsa );
@@ -6057,7 +6108,7 @@
         status = mbedtls_psa_rsa_export_key( type,
                                              &rsa,
                                              slot->key.data,
-                                             bytes,
+                                             slot->key.bytes,
                                              &slot->key.bytes );
         mbedtls_rsa_free( &rsa );
         if( status != PSA_SUCCESS )
@@ -6093,8 +6144,7 @@
 
 
         /* Make sure to always have an export representation available */
-        size_t bytes = PSA_BITS_TO_BYTES( bits );
-        status = psa_allocate_buffer_to_slot( slot, bytes );
+        status = psa_allocate_buffer_to_slot( slot, key_buffer_size );
         if( status != PSA_SUCCESS )
         {
             mbedtls_ecp_keypair_free( &ecp );
@@ -6102,11 +6152,11 @@
         }
 
         status = mbedtls_to_psa_error(
-            mbedtls_ecp_write_key( &ecp, slot->key.data, bytes ) );
+            mbedtls_ecp_write_key( &ecp, slot->key.data, slot->key.bytes ) );
 
         mbedtls_ecp_keypair_free( &ecp );
         if( status != PSA_SUCCESS ) {
-            memset( slot->key.data, 0, bytes );
+            memset( slot->key.data, 0, slot->key.bytes );
             psa_remove_key_data_from_memory( slot );
         }
         return( status );