Refactor test_suite_ssl tests to enable cache setting
Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index efec5ea..2101ed8 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -9,6 +9,10 @@
 #include <ssl_tls13_invasive.h>
 #include "test/certs.h"
 
+#if defined(MBEDTLS_SSL_CACHE_C)
+#include "mbedtls/ssl_cache.h"
+#endif
+
 #include <psa/crypto.h>
 
 #include <constant_time_internal.h>
@@ -82,6 +86,9 @@
     void (*srv_log_fun)(void *, int, const char *, int, const char *);
     void (*cli_log_fun)(void *, int, const char *, int, const char *);
     int resize_buffers;
+#if defined(MBEDTLS_SSL_CACHE_C)
+    mbedtls_ssl_cache_context* cache;
+#endif
 } handshake_test_options;
 
 void init_handshake_options( handshake_test_options *opts )
@@ -114,6 +121,20 @@
   opts->srv_log_fun = NULL;
   opts->cli_log_fun = NULL;
   opts->resize_buffers = 1;
+#if defined(MBEDTLS_SSL_CACHE_C)
+  opts->cache = malloc( sizeof( mbedtls_ssl_cache_context ) );
+  mbedtls_ssl_cache_init( opts->cache );
+#endif
+}
+
+void free_handshake_options( handshake_test_options *opts )
+{
+#if defined(MBEDTLS_SSL_CACHE_C)
+    mbedtls_ssl_cache_free( opts->cache );
+    free( opts->cache );
+#else
+    (void) opts;
+#endif
 }
 /*
  * Buffer structure for custom I/O callbacks.
@@ -918,9 +939,8 @@
  *
  * \retval  0 on success, otherwise error code.
  */
-
-int mbedtls_endpoint_init( mbedtls_endpoint *ep, int endpoint_type, int pk_alg,
-                           int opaque_alg, int opaque_alg2, int opaque_usage,
+int mbedtls_endpoint_init( mbedtls_endpoint *ep, int endpoint_type,
+                           handshake_test_options* options,
                            mbedtls_test_message_socket_context *dtls_context,
                            mbedtls_test_message_queue *input_queue,
                            mbedtls_test_message_queue *output_queue,
@@ -1002,6 +1022,15 @@
 
     mbedtls_ssl_conf_authmode( &( ep->conf ), MBEDTLS_SSL_VERIFY_REQUIRED );
 
+#if defined(MBEDTLS_SSL_CACHE_C)
+    if( endpoint_type == MBEDTLS_SSL_IS_SERVER && options->cache != NULL )
+    {
+        mbedtls_ssl_conf_session_cache( &( ep->conf ), options->cache,
+                                   mbedtls_ssl_cache_get,
+                                   mbedtls_ssl_cache_set );
+    }
+#endif
+
     ret = mbedtls_ssl_setup( &( ep->ssl ), &( ep->conf ) );
     TEST_ASSERT( ret == 0 );
 
@@ -1010,8 +1039,10 @@
          mbedtls_ssl_conf_dtls_cookies( &( ep->conf ), NULL, NULL, NULL );
 #endif
 
-    ret = mbedtls_endpoint_certificate_init( ep, pk_alg, opaque_alg,
-                                             opaque_alg2, opaque_usage );
+    ret = mbedtls_endpoint_certificate_init( ep, options->pk_alg,
+                                             options->opaque_alg,
+                                             options->opaque_alg2,
+                                             options->opaque_usage );
     TEST_ASSERT( ret == 0 );
 
     TEST_EQUAL( mbedtls_ssl_conf_get_user_data_n( &ep->conf ), user_data_n );
@@ -1984,11 +2015,7 @@
     if( options->dtls != 0 )
     {
         TEST_ASSERT( mbedtls_endpoint_init( &client, MBEDTLS_SSL_IS_CLIENT,
-                                            options->pk_alg,
-                                            options->opaque_alg,
-                                            options->opaque_alg2,
-                                            options->opaque_usage,
-                                            &client_context,
+                                            options, &client_context,
                                             &client_queue,
                                             &server_queue, NULL ) == 0 );
 #if defined(MBEDTLS_TIMING_C)
@@ -2000,11 +2027,7 @@
     else
     {
         TEST_ASSERT( mbedtls_endpoint_init( &client, MBEDTLS_SSL_IS_CLIENT,
-                                            options->pk_alg,
-                                            options->opaque_alg,
-                                            options->opaque_alg2,
-                                            options->opaque_usage,
-                                            NULL, NULL,
+                                            options, NULL, NULL,
                                             NULL, NULL ) == 0 );
     }
 
@@ -2038,11 +2061,7 @@
     if( options->dtls != 0 )
     {
         TEST_ASSERT( mbedtls_endpoint_init( &server, MBEDTLS_SSL_IS_SERVER,
-                                            options->pk_alg,
-                                            options->opaque_alg,
-                                            options->opaque_alg2,
-                                            options->opaque_usage,
-                                            &server_context,
+                                            options, &server_context,
                                             &server_queue,
                                             &client_queue, NULL ) == 0 );
 #if defined(MBEDTLS_TIMING_C)
@@ -2054,12 +2073,8 @@
     else
     {
         TEST_ASSERT( mbedtls_endpoint_init( &server, MBEDTLS_SSL_IS_SERVER,
-                                            options->pk_alg,
-                                            options->opaque_alg,
-                                            options->opaque_alg2,
-                                            options->opaque_usage,
-                                            NULL, NULL,
-                                            NULL, NULL ) == 0 );
+                                            options, NULL, NULL, NULL,
+                                            NULL ) == 0 );
     }
 
     mbedtls_ssl_conf_authmode( &server.conf, options->srv_auth_mode );
@@ -4771,20 +4786,24 @@
     enum { BUFFSIZE = 1024 };
     mbedtls_endpoint ep;
     int ret = -1;
+    handshake_test_options options;
+    init_handshake_options( &options );
+    options.pk_alg = MBEDTLS_PK_RSA;
 
-    ret = mbedtls_endpoint_init( NULL, endpoint_type, MBEDTLS_PK_RSA,
-                                 0, 0, 0, NULL, NULL, NULL, NULL );
+    ret = mbedtls_endpoint_init( NULL, endpoint_type, &options,
+                                 NULL, NULL, NULL, NULL );
     TEST_ASSERT( MBEDTLS_ERR_SSL_BAD_INPUT_DATA == ret );
 
-    ret = mbedtls_endpoint_certificate_init( NULL, MBEDTLS_PK_RSA, 0, 0, 0 );
+    ret = mbedtls_endpoint_certificate_init( NULL, options.pk_alg, 0, 0, 0 );
     TEST_ASSERT( MBEDTLS_ERR_SSL_BAD_INPUT_DATA == ret );
 
-    ret = mbedtls_endpoint_init( &ep, endpoint_type, MBEDTLS_PK_RSA, 0, 0, 0,
+    ret = mbedtls_endpoint_init( &ep, endpoint_type, &options,
                                  NULL, NULL, NULL, NULL );
     TEST_ASSERT( ret == 0 );
 
 exit:
     mbedtls_endpoint_free( &ep, NULL );
+    free_handshake_options( &options );
 }
 /* END_CASE */
 
@@ -4794,18 +4813,21 @@
     enum { BUFFSIZE = 1024 };
     mbedtls_endpoint base_ep, second_ep;
     int ret = -1;
+    handshake_test_options options;
+    init_handshake_options( &options );
+    options.pk_alg = MBEDTLS_PK_RSA;
 
     USE_PSA_INIT( );
 
-    ret = mbedtls_endpoint_init( &base_ep, endpoint_type, MBEDTLS_PK_RSA,
-                                 0, 0, 0, NULL, NULL, NULL, NULL );
+    ret = mbedtls_endpoint_init( &base_ep, endpoint_type, &options,
+                                 NULL, NULL, NULL, NULL );
     TEST_ASSERT( ret == 0 );
 
     ret = mbedtls_endpoint_init( &second_ep,
                             ( endpoint_type == MBEDTLS_SSL_IS_SERVER ) ?
                             MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
-                                 MBEDTLS_PK_RSA, 0, 0, 0,
-                                 NULL, NULL, NULL, NULL );
+                                 &options, NULL, NULL, NULL, NULL );
+
     TEST_ASSERT( ret == 0 );
 
     ret = mbedtls_mock_socket_connect( &(base_ep.socket),
@@ -4832,6 +4854,7 @@
     }
 
 exit:
+    free_handshake_options( &options );
     mbedtls_endpoint_free( &base_ep, NULL );
     mbedtls_endpoint_free( &second_ep, NULL );
     USE_PSA_DONE( );
@@ -4931,8 +4954,12 @@
 #endif
 
     perform_handshake( &options );
+
     /* The goto below is used to avoid an "unused label" warning.*/
     goto exit;
+
+exit:
+    free_handshake_options( &options );
 }
 /* END_CASE */
 
@@ -5007,6 +5034,9 @@
     {
         TEST_ASSERT( cli_pattern.counter >= 1 );
     }
+
+exit:
+    free_handshake_options( &options );
 }
 /* END_CASE */
 
@@ -5021,8 +5051,12 @@
     options.dtls = 1;
 
     perform_handshake( &options );
+
     /* The goto below is used to avoid an "unused label" warning.*/
     goto exit;
+
+exit:
+    free_handshake_options( &options );
 }
 /* END_CASE */
 
@@ -5042,8 +5076,11 @@
     options.resize_buffers = 1;
 
     perform_handshake( &options );
+
     /* The goto below is used to avoid an "unused label" warning.*/
     goto exit;
+exit:
+    free_handshake_options( &options );
 }
 /* END_CASE */
 
@@ -5055,6 +5092,9 @@
 
     /* The goto below is used to avoid an "unused label" warning.*/
     goto exit;
+
+exit:
+    free_handshake_options( &options );
 }
 /* END_CASE */
 
@@ -5066,6 +5106,9 @@
 
     /* The goto below is used to avoid an "unused label" warning.*/
     goto exit;
+
+exit:
+    free_handshake_options( &options );
 }
 /* END_CASE */