Merge pull request #5779 from mprse/rsa_decr_1b

RSA decrypt 1b: prepare for TLS testing
diff --git a/programs/ssl/ssl_client2.c b/programs/ssl/ssl_client2.c
index f741d99..2cfdde6 100644
--- a/programs/ssl/ssl_client2.c
+++ b/programs/ssl/ssl_client2.c
@@ -115,6 +115,7 @@
 #define DFL_USE_SRTP            0
 #define DFL_SRTP_FORCE_PROFILE  0
 #define DFL_SRTP_MKI            ""
+#define DFL_KEY_OPAQUE_ALG      "none" /* alg1 "none" => key-type defaults */
 
 #define GET_REQUEST "GET %s HTTP/1.0\r\nExtra-header: "
 #define GET_REQUEST_END "\r\n\r\n"
@@ -343,6 +344,13 @@
 #define USAGE_SERIALIZATION ""
 #endif
 
+#define USAGE_KEY_OPAQUE_ALGS \
+    "    key_opaque_algs=%%s  Allowed opaque key algorithms.\n"                      \
+    "                        comma-separated pair of values among the following:\n"  \
+    "                        rsa-sign-pkcs1, rsa-sign-pss, rsa-decrypt,\n"           \
+    "                        ecdsa-sign, ecdh, none (only acceptable for\n"          \
+    "                        the second value).\n"
+
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
 #define USAGE_TLS1_3_KEY_EXCHANGE_MODES \
     "    tls13_kex_modes=%%s   default: all\n"     \
@@ -411,6 +419,7 @@
     USAGE_CURVES                                            \
     USAGE_SIG_ALGS                                          \
     USAGE_DHMLEN                                            \
+    USAGE_KEY_OPAQUE_ALGS                                   \
     "\n"
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
@@ -523,6 +532,8 @@
     int use_srtp;               /* Support SRTP                             */
     int force_srtp_profile;     /* SRTP protection profile to use or all    */
     const char *mki;            /* The dtls mki value to use                */
+    const char *key_opaque_alg1; /* Allowed opaque key alg 1                */
+    const char *key_opaque_alg2; /* Allowed Opaque key alg 2                */
 } opt;
 
 #include "ssl_test_common_source.c"
@@ -885,6 +896,8 @@
     opt.use_srtp            = DFL_USE_SRTP;
     opt.force_srtp_profile  = DFL_SRTP_FORCE_PROFILE;
     opt.mki                 = DFL_SRTP_MKI;
+    opt.key_opaque_alg1     = DFL_KEY_OPAQUE_ALG;
+    opt.key_opaque_alg2     = DFL_KEY_OPAQUE_ALG;
 
     for( i = 1; i < argc; i++ )
     {
@@ -1308,6 +1321,12 @@
         {
             opt.mki = q;
         }
+        else if( strcmp( p, "key_opaque_algs" ) == 0 )
+        {
+            if( key_opaque_alg_parse( q, &opt.key_opaque_alg1,
+                                         &opt.key_opaque_alg2 ) != 0 )
+                goto usage;
+        }
         else
             goto usage;
     }
@@ -1698,26 +1717,24 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( opt.key_opaque != 0 )
     {
-        psa_algorithm_t psa_alg, psa_alg2;
+        psa_algorithm_t psa_alg, psa_alg2 = PSA_ALG_NONE;
+        psa_key_usage_t usage = 0;
 
-        if( mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_ECKEY )
+        if( key_opaque_set_alg_usage( opt.key_opaque_alg1,
+                                      opt.key_opaque_alg2,
+                                      &psa_alg, &psa_alg2,
+                                      &usage,
+                                      mbedtls_pk_get_type( &pkey ) ) == 0 )
         {
-            psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
-            psa_alg2 = PSA_ALG_NONE;
-        }
-        else
-        {
-            psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
-            psa_alg2 = PSA_ALG_RSA_PSS( PSA_ALG_ANY_HASH );
-        }
-
-        if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot, psa_alg,
-                                               PSA_KEY_USAGE_SIGN_HASH,
-                                               psa_alg2 ) ) != 0 )
-        {
-            mbedtls_printf( " failed\n  !  "
-                            "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n", (unsigned int)  -ret );
-            goto exit;
+            ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot, psa_alg,
+                                             usage, psa_alg2 );
+            if( ret != 0 )
+            {
+                mbedtls_printf( " failed\n  !  "
+                                "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n",
+                                (unsigned int)  -ret );
+                goto exit;
+            }
         }
     }
 #endif /* MBEDTLS_USE_PSA_CRYPTO */
diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c
index d728b95..0047cab 100644
--- a/programs/ssl/ssl_server2.c
+++ b/programs/ssl/ssl_server2.c
@@ -151,6 +151,7 @@
 #define DFL_USE_SRTP            0
 #define DFL_SRTP_FORCE_PROFILE  0
 #define DFL_SRTP_SUPPORT_MKI    0
+#define DFL_KEY_OPAQUE_ALG      "none" /* alg1 "none" => key-type defaults */
 
 #define LONG_RESPONSE "<p>01-blah-blah-blah-blah-blah-blah-blah-blah-blah\r\n" \
     "02-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah-blah\r\n"  \
@@ -455,6 +456,17 @@
 #define USAGE_SERIALIZATION ""
 #endif
 
+/* Value list shared by the key_opaque_algs and key_opaque_algs2 help. */
+#define USAGE_KEY_OPAQUE_ALG_LIST \
+    "                        comma-separated pair of values among the following:\n"  \
+    "                        rsa-sign-pkcs1, rsa-sign-pss, rsa-decrypt,\n"           \
+    "                        ecdsa-sign, ecdh, none (only acceptable for\n"          \
+    "                        the second value).\n"
+#define USAGE_KEY_OPAQUE_ALGS \
+    "    key_opaque_algs=%%s  Allowed opaque key 1 algorithms.\n"                    \
+    USAGE_KEY_OPAQUE_ALG_LIST                                                        \
+    "    key_opaque_algs2=%%s Allowed opaque key 2 algorithms.\n"                    \
+    USAGE_KEY_OPAQUE_ALG_LIST
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
 #define USAGE_TLS1_3_KEY_EXCHANGE_MODES \
     "    tls13_kex_modes=%%s   default: all\n"     \
@@ -519,6 +531,7 @@
     USAGE_ETM                                               \
     USAGE_CURVES                                            \
     USAGE_SIG_ALGS                                          \
+    USAGE_KEY_OPAQUE_ALGS                                   \
     "\n"
 
 #if defined(MBEDTLS_SSL_PROTO_TLS1_3)
@@ -659,6 +672,10 @@
     int use_srtp;               /* Support SRTP                             */
     int force_srtp_profile;     /* SRTP protection profile to use or all    */
     int support_mki;            /* The dtls mki mki support                 */
+    const char *key1_opaque_alg1; /* Allowed opaque key 1 alg 1            */
+    const char *key1_opaque_alg2; /* Allowed Opaque key 1 alg 2            */
+    const char *key2_opaque_alg1; /* Allowed opaque key 2 alg 1            */
+    const char *key2_opaque_alg2; /* Allowed Opaque key 2 alg 2            */
 } opt;
 
 #include "ssl_test_common_source.c"
@@ -679,7 +696,7 @@
 }
 
 /*
- * Used by sni_parse and psk_parse to handle coma-separated lists
+ * Used by sni_parse and psk_parse to handle comma-separated lists
  */
 #define GET_ITEM( dst )         \
     do                          \
@@ -1615,6 +1632,10 @@
     opt.use_srtp            = DFL_USE_SRTP;
     opt.force_srtp_profile  = DFL_SRTP_FORCE_PROFILE;
     opt.support_mki         = DFL_SRTP_SUPPORT_MKI;
+    opt.key1_opaque_alg1   = DFL_KEY_OPAQUE_ALG;
+    opt.key1_opaque_alg2   = DFL_KEY_OPAQUE_ALG;
+    opt.key2_opaque_alg1   = DFL_KEY_OPAQUE_ALG;
+    opt.key2_opaque_alg2   = DFL_KEY_OPAQUE_ALG;
 
     for( i = 1; i < argc; i++ )
     {
@@ -2088,6 +2109,18 @@
         {
             opt.support_mki = atoi( q );
         }
+        else if( strcmp( p, "key_opaque_algs" ) == 0 )
+        {
+            if( key_opaque_alg_parse( q, &opt.key1_opaque_alg1,
+                                         &opt.key1_opaque_alg2 ) != 0 )
+                goto usage;
+        }
+        else if( strcmp( p, "key_opaque_algs2" ) == 0 )
+        {
+            if( key_opaque_alg_parse( q, &opt.key2_opaque_alg1,
+                                         &opt.key2_opaque_alg2 ) != 0 )
+                goto usage;
+        }
         else
             goto usage;
     }
@@ -2564,59 +2597,44 @@
 #if defined(MBEDTLS_USE_PSA_CRYPTO)
     if( opt.key_opaque != 0 )
     {
-        psa_algorithm_t psa_alg, psa_alg2;
-        psa_key_usage_t psa_usage;
+        psa_algorithm_t psa_alg, psa_alg2 = PSA_ALG_NONE;
+        psa_key_usage_t psa_usage = 0;
 
-        if ( mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_ECKEY ||
-             mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_RSA )
+        if( key_opaque_set_alg_usage( opt.key1_opaque_alg1,
+                                      opt.key1_opaque_alg2,
+                                      &psa_alg, &psa_alg2,
+                                      &psa_usage,
+                                      mbedtls_pk_get_type( &pkey ) ) == 0 )
         {
-            if( mbedtls_pk_get_type( &pkey ) == MBEDTLS_PK_ECKEY )
-            {
-                psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
-                psa_alg2 = PSA_ALG_ECDH;
-                psa_usage = PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_DERIVE;
-            }
-            else
-            {
-                psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
-                psa_alg2 = PSA_ALG_NONE;
-                psa_usage = PSA_KEY_USAGE_SIGN_HASH;
-            }
+            ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot,
+                                             psa_alg, psa_usage, psa_alg2 );
 
-            if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey, &key_slot,
-                                                   psa_alg,
-                                                   psa_usage,
-                                                   psa_alg2 ) ) != 0 )
+            if( ret != 0 )
             {
                 mbedtls_printf( " failed\n  !  "
-                                "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n", (unsigned int)  -ret );
+                                "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n",
+                                (unsigned int)  -ret );
                 goto exit;
             }
         }
 
-        if ( mbedtls_pk_get_type( &pkey2 ) == MBEDTLS_PK_ECKEY ||
-             mbedtls_pk_get_type( &pkey2 ) == MBEDTLS_PK_RSA )
-        {
-            if( mbedtls_pk_get_type( &pkey2 ) == MBEDTLS_PK_ECKEY )
-            {
-                psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
-                psa_alg2 = PSA_ALG_ECDH;
-                psa_usage = PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_DERIVE;
-            }
-            else
-            {
-                psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
-                psa_alg2 = PSA_ALG_NONE;
-                psa_usage = PSA_KEY_USAGE_SIGN_HASH;
-            }
+        psa_alg = PSA_ALG_NONE; psa_alg2 = PSA_ALG_NONE;
+        psa_usage = 0;
 
-            if( ( ret = mbedtls_pk_wrap_as_opaque( &pkey2, &key_slot2,
-                                                   psa_alg,
-                                                   psa_usage,
-                                                   psa_alg2 ) ) != 0 )
+        if( key_opaque_set_alg_usage( opt.key2_opaque_alg1,
+                                      opt.key2_opaque_alg2,
+                                      &psa_alg, &psa_alg2,
+                                      &psa_usage,
+                                      mbedtls_pk_get_type( &pkey2 ) ) == 0 )
+        {
+            ret = mbedtls_pk_wrap_as_opaque( &pkey2, &key_slot2,
+                                             psa_alg, psa_usage, psa_alg2 );
+
+            if( ret != 0 )
             {
                 mbedtls_printf( " failed\n  !  "
-                                "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n", (unsigned int)  -ret );
+                                "mbedtls_pk_wrap_as_opaque returned -0x%x\n\n",
+                                (unsigned int)  -ret );
                 goto exit;
             }
         }
diff --git a/programs/ssl/ssl_test_lib.c b/programs/ssl/ssl_test_lib.c
index a28a477..a7f3d0e 100644
--- a/programs/ssl/ssl_test_lib.c
+++ b/programs/ssl/ssl_test_lib.c
@@ -193,6 +193,151 @@
 #endif /* !MBEDTLS_TEST_USE_PSA_CRYPTO_RNG */
 }
 
+/* Return 1 if alg is one of the algorithm names that key_opaque_algs
+ * accepts for either value, 0 otherwise. "none" is deliberately absent:
+ * it is only acceptable as the second value and is checked by the caller. */
+static int key_opaque_alg_is_known( const char *alg )
+{
+    return( strcmp( alg, "rsa-sign-pkcs1" ) == 0 ||
+            strcmp( alg, "rsa-sign-pss" ) == 0 ||
+            strcmp( alg, "rsa-decrypt" ) == 0 ||
+            strcmp( alg, "ecdsa-sign" ) == 0 ||
+            strcmp( alg, "ecdh" ) == 0 );
+}
+
+/*
+ * Parse the value of the key_opaque_algs option: "alg1,alg2".
+ *
+ * arg is split in place: the first comma is overwritten with '\0' and
+ * *alg1 / *alg2 are set to point into arg (they are set even when
+ * validation then fails). arg must therefore be writable and outlive the
+ * returned pointers; the command-line callers presumably pass a pointer
+ * into argv, which satisfies both.
+ *
+ * Returns 0 on success, 1 if there is no comma or either value is invalid.
+ */
+int key_opaque_alg_parse( const char *arg, const char **alg1, const char **alg2 )
+{
+    char *separator = strchr( arg, ',' );
+
+    if( separator == NULL )
+        return 1;
+    *separator = '\0';
+
+    *alg1 = arg;
+    *alg2 = separator + 1;
+
+    /* "none" is accepted for the second value only. */
+    if( ! key_opaque_alg_is_known( *alg1 ) )
+        return 1;
+
+    if( ! key_opaque_alg_is_known( *alg2 ) && strcmp( *alg2, "none" ) != 0 )
+        return 1;
+
+    return 0;
+}
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/* Map one opaque key algorithm name to its PSA algorithm and OR the key
+ * usage flag that algorithm needs into *usage. Returns 0 on success, 1 if
+ * the name is not recognised. */
+static int key_opaque_alg_to_psa( const char *alg,
+                                  psa_algorithm_t *psa_alg,
+                                  psa_key_usage_t *usage )
+{
+    if( strcmp( alg, "rsa-sign-pkcs1" ) == 0 )
+    {
+        *psa_alg = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
+        *usage |= PSA_KEY_USAGE_SIGN_HASH;
+    }
+    else if( strcmp( alg, "rsa-sign-pss" ) == 0 )
+    {
+        *psa_alg = PSA_ALG_RSA_PSS( PSA_ALG_ANY_HASH );
+        *usage |= PSA_KEY_USAGE_SIGN_HASH;
+    }
+    else if( strcmp( alg, "rsa-decrypt" ) == 0 )
+    {
+        *psa_alg = PSA_ALG_RSA_PKCS1V15_CRYPT;
+        *usage |= PSA_KEY_USAGE_DECRYPT;
+    }
+    else if( strcmp( alg, "ecdsa-sign" ) == 0 )
+    {
+        *psa_alg = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
+        *usage |= PSA_KEY_USAGE_SIGN_HASH;
+    }
+    else if( strcmp( alg, "ecdh" ) == 0 )
+    {
+        *psa_alg = PSA_ALG_ECDH;
+        *usage |= PSA_KEY_USAGE_DERIVE;
+    }
+    else if( strcmp( alg, "none" ) == 0 )
+    {
+        *psa_alg = PSA_ALG_NONE;
+    }
+    else
+    {
+        return 1;
+    }
+
+    return 0;
+}
+
+/*
+ * Translate the opaque key algorithm names into the PSA algorithms and key
+ * usage to pass to mbedtls_pk_wrap_as_opaque().
+ *
+ * If alg1 is "none", defaults are chosen from key_type (ECDSA + ECDH for
+ * MBEDTLS_PK_ECKEY, PKCS#1 v1.5 + PSS signing for MBEDTLS_PK_RSA) and alg2
+ * is ignored. Otherwise alg1 and alg2 are translated as given and key_type
+ * is not consulted.
+ *
+ * On success the three outputs are overwritten and 0 is returned. 1 is
+ * returned, with the outputs left untouched, if alg1 is "none" and
+ * key_type is unsupported, or if an algorithm name is not recognised.
+ */
+int key_opaque_set_alg_usage( const char *alg1, const char *alg2,
+                              psa_algorithm_t *psa_alg1,
+                              psa_algorithm_t *psa_alg2,
+                              psa_key_usage_t *usage,
+                              mbedtls_pk_type_t key_type )
+{
+    psa_algorithm_t alg1_psa = PSA_ALG_NONE;
+    psa_algorithm_t alg2_psa = PSA_ALG_NONE;
+    psa_key_usage_t new_usage = 0;
+
+    if( strcmp( alg1, "none" ) != 0 )
+    {
+        /* Fail on unknown names rather than report success with *psa_alg1
+         * unset: callers such as ssl_client2 leave psa_alg uninitialized. */
+        if( key_opaque_alg_to_psa( alg1, &alg1_psa, &new_usage ) != 0 ||
+            key_opaque_alg_to_psa( alg2, &alg2_psa, &new_usage ) != 0 )
+            return 1;
+    }
+    else if( key_type == MBEDTLS_PK_ECKEY )
+    {
+        alg1_psa = PSA_ALG_ECDSA( PSA_ALG_ANY_HASH );
+        alg2_psa = PSA_ALG_ECDH;
+        new_usage = PSA_KEY_USAGE_SIGN_HASH | PSA_KEY_USAGE_DERIVE;
+    }
+    else if( key_type == MBEDTLS_PK_RSA )
+    {
+        alg1_psa = PSA_ALG_RSA_PKCS1V15_SIGN( PSA_ALG_ANY_HASH );
+        alg2_psa = PSA_ALG_RSA_PSS( PSA_ALG_ANY_HASH );
+        new_usage = PSA_KEY_USAGE_SIGN_HASH;
+    }
+    else
+    {
+        return 1;
+    }
+
+    *psa_alg1 = alg1_psa;
+    *psa_alg2 = alg2_psa;
+    *usage = new_usage;
+
+    return 0;
+}
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 #if defined(MBEDTLS_X509_TRUSTED_CERTIFICATE_CALLBACK)
 int ca_callback( void *data, mbedtls_x509_crt const *child,
                  mbedtls_x509_crt **candidates )
diff --git a/programs/ssl/ssl_test_lib.h b/programs/ssl/ssl_test_lib.h
index a359b3f..f0d0c3b 100644
--- a/programs/ssl/ssl_test_lib.h
+++ b/programs/ssl/ssl_test_lib.h
@@ -221,6 +221,58 @@
  */
 int rng_get( void *p_rng, unsigned char *output, size_t output_len );
 
+/** Parse command-line option: key_opaque_algs
+ *
+ * \param arg           String value of key_opaque_algs:
+ *                      comma-separated pair of values among the following:
+ *                      - "rsa-sign-pkcs1"
+ *                      - "rsa-sign-pss"
+ *                      - "rsa-decrypt"
+ *                      - "ecdsa-sign"
+ *                      - "ecdh"
+ *                      - "none" (only acceptable for the second value).
+ *                      The string is split in place (the first comma is
+ *                      overwritten with '\0'), so it must be writable.
+ * \param alg1          On return, points to the first value inside \p arg.
+ * \param alg2          On return, points to the second value inside \p arg.
+ *                      Both are assigned whenever \p arg contains a comma,
+ *                      even if validation subsequently fails.
+ *
+ * \return              \c 0 on success.
+ * \return              \c 1 if \p arg contains no comma or either value is
+ *                      not accepted.
+ */
+int key_opaque_alg_parse( const char *arg, const char **alg1, const char **alg2 );
+
+#if defined(MBEDTLS_USE_PSA_CRYPTO)
+/** Translate opaque key algorithm names into the PSA algorithms and key
+ *  usage that will be passed to mbedtls_pk_wrap_as_opaque().
+ *
+ * \param alg1          input string opaque key algorithm #1, normally
+ *                      validated by key_opaque_alg_parse(), or "none" to
+ *                      select defaults based on \p key_type.
+ * \param alg2          input string opaque key algorithm #2 (ignored when
+ *                      \p alg1 is "none").
+ * \param psa_alg1      output PSA algorithm #1
+ * \param psa_alg2      output PSA algorithm #2
+ * \param usage         output key usage; initialize it to 0 before the call.
+ * \param key_type      key type used to select the default PSA algorithms
+ *                      and usage when \p alg1 is "none": ECDSA + ECDH for
+ *                      #MBEDTLS_PK_ECKEY, PKCS#1 v1.5 + PSS signing for
+ *                      #MBEDTLS_PK_RSA.
+ *
+ * \return              \c 0 on success.
+ * \return              \c 1 on failure, e.g. if \p alg1 is "none" and
+ *                      \p key_type is neither #MBEDTLS_PK_ECKEY nor
+ *                      #MBEDTLS_PK_RSA. The outputs must not be used then.
+ */
+int key_opaque_set_alg_usage( const char *alg1, const char *alg2,
+                              psa_algorithm_t *psa_alg1,
+                              psa_algorithm_t *psa_alg2,
+                              psa_key_usage_t *usage,
+                              mbedtls_pk_type_t key_type );
+#endif /* MBEDTLS_USE_PSA_CRYPTO */
+
 #if defined(MBEDTLS_USE_PSA_CRYPTO) && defined(MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG)
 /* The test implementation of the PSA external RNG is insecure. When
  * MBEDTLS_PSA_CRYPTO_EXTERNAL_RNG is enabled, before using any PSA crypto