psa: Add RSA sign/verify hash support to the transparent test driver

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
diff --git a/library/psa_crypto_rsa.c b/library/psa_crypto_rsa.c
index 2886146..3e95d3a 100644
--- a/library/psa_crypto_rsa.c
+++ b/library/psa_crypto_rsa.c
@@ -52,10 +52,24 @@
 #define BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY  1
 #endif
 
+#if ( defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||  \
+      ( defined(PSA_CRYPTO_DRIVER_TEST) &&                   \
+        defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) &&  \
+        defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PKCS1_V15) ) )
+#define BUILTIN_ALG_RSA_PKCS1V15_SIGN  1
+#endif
+
+#if ( defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||  \
+      ( defined(PSA_CRYPTO_DRIVER_TEST) &&         \
+        defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS) &&  \
+        defined(MBEDTLS_RSA_C) && defined(MBEDTLS_PKCS1_V21) ) )
+#define BUILTIN_ALG_RSA_PSS 1
+#endif
+
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) || \
-    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
     defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) || \
-    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) || \
+    defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
+    defined(BUILTIN_ALG_RSA_PSS) || \
     defined(BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
     defined(BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
 
@@ -131,9 +145,9 @@
     return( status );
 }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_CRYPT) ||
-        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
         * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_OAEP) ||
-        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) ||
+        * defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(BUILTIN_ALG_RSA_PSS) ||
         * defined(BUILTIN_KEY_TYPE_RSA_KEY_PAIR) ||
         * defined(BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY) */
 
@@ -323,8 +337,7 @@
 /* Sign/verify hashes */
 /****************************************************************/
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
-    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+#if defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN) || defined(BUILTIN_ALG_RSA_PSS)
 
 /* Decode the hash algorithm from alg and store the mbedtls encoding in
  * md_alg. Verify that the hash length is acceptable. */
@@ -344,7 +357,7 @@
         return( PSA_ERROR_INVALID_ARGUMENT );
 #endif
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
+#if defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN)
     /* For PKCS#1 v1.5 signature, if using a hash, the hash length
      * must be correct. */
     if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) &&
@@ -355,21 +368,21 @@
         if( mbedtls_md_get_size( md_info ) != hash_length )
             return( PSA_ERROR_INVALID_ARGUMENT );
     }
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
+#endif /* BUILTIN_ALG_RSA_PKCS1V15_SIGN */
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+#if defined(BUILTIN_ALG_RSA_PSS)
     /* PSS requires a hash internally. */
     if( PSA_ALG_IS_RSA_PSS( alg ) )
     {
         if( md_info == NULL )
             return( PSA_ERROR_NOT_SUPPORTED );
     }
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
+#endif /* BUILTIN_ALG_RSA_PSS */
 
     return( PSA_SUCCESS );
 }
 
-psa_status_t mbedtls_psa_rsa_sign_hash(
+static psa_status_t rsa_sign_hash(
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
@@ -397,7 +410,7 @@
         goto exit;
     }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
+#if defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN)
     if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
     {
         mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
@@ -412,8 +425,8 @@
                                       signature );
     }
     else
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+#endif /* BUILTIN_ALG_RSA_PKCS1V15_SIGN */
+#if defined(BUILTIN_ALG_RSA_PSS)
     if( PSA_ALG_IS_RSA_PSS( alg ) )
     {
         mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
@@ -427,7 +440,7 @@
                                            signature );
     }
     else
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
+#endif /* BUILTIN_ALG_RSA_PSS */
     {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
@@ -444,7 +457,7 @@
     return( status );
 }
 
-psa_status_t mbedtls_psa_rsa_verify_hash(
+static psa_status_t rsa_verify_hash(
     const psa_key_attributes_t *attributes,
     const uint8_t *key_buffer, size_t key_buffer_size,
     psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
@@ -472,7 +485,7 @@
         goto exit;
     }
 
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN)
+#if defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN)
     if( PSA_ALG_IS_RSA_PKCS1V15_SIGN( alg ) )
     {
         mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V15,
@@ -487,8 +500,8 @@
                                         signature );
     }
     else
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN */
-#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+#endif /* BUILTIN_ALG_RSA_PKCS1V15_SIGN */
+#if defined(BUILTIN_ALG_RSA_PSS)
     if( PSA_ALG_IS_RSA_PSS( alg ) )
     {
         mbedtls_rsa_set_padding( rsa, MBEDTLS_RSA_PKCS_V21, md_alg );
@@ -502,7 +515,7 @@
                                              signature );
     }
     else
-#endif /* MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS */
+#endif /* BUILTIN_ALG_RSA_PSS */
     {
         status = PSA_ERROR_INVALID_ARGUMENT;
         goto exit;
@@ -522,8 +535,8 @@
     return( status );
 }
 
-#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
-        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
+#endif /* defined(BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(BUILTIN_ALG_RSA_PSS) */
 
 #if defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) || \
     defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_PUBLIC_KEY)
@@ -561,6 +574,36 @@
 }
 #endif /* defined(MBEDTLS_PSA_BUILTIN_KEY_TYPE_RSA_KEY_PAIR) */
 
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) || \
+    defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS)
+psa_status_t mbedtls_psa_rsa_sign_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    uint8_t *signature, size_t signature_size, size_t *signature_length )
+{
+    return( rsa_sign_hash(
+                attributes,
+                key_buffer, key_buffer_size,
+                alg, hash, hash_length,
+                signature, signature_size, signature_length ) );
+}
+
+psa_status_t mbedtls_psa_rsa_verify_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    const uint8_t *signature, size_t signature_length )
+{
+    return( rsa_verify_hash(
+                attributes,
+                key_buffer, key_buffer_size,
+                alg, hash, hash_length,
+                signature, signature_length ) );
+}
+#endif /* defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(MBEDTLS_PSA_BUILTIN_ALG_RSA_PSS) */
+
 /*
  * BEYOND THIS POINT, TEST DRIVER ENTRY POINTS ONLY.
  */
@@ -603,6 +646,63 @@
 }
 #endif /* defined(MBEDTLS_PSA_ACCEL_KEY_TYPE_RSA_KEY_PAIR) */
 
+#if defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) || \
+    defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS)
+psa_status_t mbedtls_transparent_test_driver_rsa_sign_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    uint8_t *signature, size_t signature_size, size_t *signature_length )
+{
+#if defined(MBEDTLS_RSA_C) && \
+    (defined(MBEDTLS_PKCS1_V15) || defined(MBEDTLS_PKCS1_V21))
+    return( rsa_sign_hash(
+                attributes,
+                key_buffer, key_buffer_size,
+                alg, hash, hash_length,
+                signature, signature_size, signature_length ) );
+#else
+    (void)attributes;
+    (void)key_buffer;
+    (void)key_buffer_size;
+    (void)alg;
+    (void)hash;
+    (void)hash_length;
+    (void)signature;
+    (void)signature_size;
+    (void)signature_length;
+    return( PSA_ERROR_NOT_SUPPORTED );
+#endif
+}
+
+psa_status_t mbedtls_transparent_test_driver_rsa_verify_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    const uint8_t *signature, size_t signature_length )
+{
+#if defined(MBEDTLS_RSA_C) && \
+    (defined(MBEDTLS_PKCS1_V15) || defined(MBEDTLS_PKCS1_V21))
+    return( rsa_verify_hash(
+                attributes,
+                key_buffer, key_buffer_size,
+                alg, hash, hash_length,
+                signature, signature_length ) );
+#else
+    (void)attributes;
+    (void)key_buffer;
+    (void)key_buffer_size;
+    (void)alg;
+    (void)hash;
+    (void)hash_length;
+    (void)signature;
+    (void)signature_length;
+    return( PSA_ERROR_NOT_SUPPORTED );
+#endif
+}
+#endif /* defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS) */
+
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 
 #endif /* MBEDTLS_PSA_CRYPTO_C */
diff --git a/library/psa_crypto_rsa.h b/library/psa_crypto_rsa.h
index bbf4e7d..41a90f7 100644
--- a/library/psa_crypto_rsa.h
+++ b/library/psa_crypto_rsa.h
@@ -233,6 +233,18 @@
     const psa_key_attributes_t *attributes,
     uint8_t *key, size_t key_size, size_t *key_length );
 
+psa_status_t mbedtls_transparent_test_driver_rsa_sign_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    uint8_t *signature, size_t signature_size, size_t *signature_length );
+
+psa_status_t mbedtls_transparent_test_driver_rsa_verify_hash(
+    const psa_key_attributes_t *attributes,
+    const uint8_t *key_buffer, size_t key_buffer_size,
+    psa_algorithm_t alg, const uint8_t *hash, size_t hash_length,
+    const uint8_t *signature, size_t signature_length );
+
 #endif /* PSA_CRYPTO_DRIVER_TEST */
 
 #endif /* PSA_CRYPTO_RSA_H */
diff --git a/tests/src/drivers/signature.c b/tests/src/drivers/signature.c
index cea0351..78b7ff9 100644
--- a/tests/src/drivers/signature.c
+++ b/tests/src/drivers/signature.c
@@ -28,6 +28,7 @@
 #if defined(MBEDTLS_PSA_CRYPTO_DRIVERS) && defined(PSA_CRYPTO_DRIVER_TEST)
 #include "psa/crypto.h"
 #include "psa_crypto_core.h"
+#include "psa_crypto_rsa.h"
 #include "mbedtls/ecp.h"
 
 #include "test/drivers/signature.h"
@@ -66,6 +67,19 @@
 
     psa_status_t status = PSA_ERROR_NOT_SUPPORTED;
 
+#if defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) || \
+    defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS)
+    if( attributes->core.type == PSA_KEY_TYPE_RSA_KEY_PAIR )
+    {
+        return( mbedtls_transparent_test_driver_rsa_sign_hash(
+                    attributes,
+                    key_buffer, key_buffer_size,
+                    alg, hash, hash_length,
+                    signature, signature_size, signature_length ) );
+    }
+#endif /* defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS) */
+
 #if defined(MBEDTLS_ECDSA_C) && defined(MBEDTLS_ECDSA_DETERMINISTIC) && \
     defined(MBEDTLS_SHA256_C)
     if( alg != PSA_ALG_DETERMINISTIC_ECDSA( PSA_ALG_SHA_256 ) )
@@ -178,6 +192,19 @@
 
     psa_status_t status = PSA_ERROR_NOT_SUPPORTED;
 
+#if defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) || \
+    defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS)
+    if( PSA_KEY_TYPE_IS_RSA( attributes->core.type ) )
+    {
+        return( mbedtls_transparent_test_driver_rsa_verify_hash(
+                    attributes,
+                    key_buffer, key_buffer_size,
+                    alg, hash, hash_length,
+                    signature, signature_length ) );
+    }
+#endif /* defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PKCS1V15_SIGN) ||
+        * defined(MBEDTLS_PSA_ACCEL_ALG_RSA_PSS) */
+
 #if defined(MBEDTLS_ECDSA_C) && defined(MBEDTLS_ECDSA_DETERMINISTIC) && \
     defined(MBEDTLS_SHA256_C)
     if( alg != PSA_ALG_DETERMINISTIC_ECDSA( PSA_ALG_SHA_256 ) )