Add I/O buffer resizing in handshake init and free

Add a conditional buffer resizing feature. Introduce tests exercising
it in various setups (serialization, renegotiation, mfl manipulations).
Signed-off-by: Andrzej Kurek <andrzej.kurek@arm.com>
diff --git a/include/mbedtls/ssl_internal.h b/include/mbedtls/ssl_internal.h
index ba4a608..4720988 100644
--- a/include/mbedtls/ssl_internal.h
+++ b/include/mbedtls/ssl_internal.h
@@ -256,6 +256,32 @@
       + ( MBEDTLS_SSL_CID_OUT_LEN_MAX ) )
 #endif
 
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+static inline uint32_t mbedtls_ssl_get_output_buflen( const mbedtls_ssl_context *ctx )
+{
+#if defined (MBEDTLS_SSL_DTLS_CONNECTION_ID)
+    return (uint32_t) mbedtls_ssl_get_max_frag_len( ctx )
+               + MBEDTLS_SSL_HEADER_LEN + MBEDTLS_SSL_PAYLOAD_OVERHEAD
+               + MBEDTLS_SSL_CID_OUT_LEN_MAX;
+#else
+    return (uint32_t) mbedtls_ssl_get_max_frag_len( ctx )
+               + MBEDTLS_SSL_HEADER_LEN + MBEDTLS_SSL_PAYLOAD_OVERHEAD;
+#endif
+}
+
+static inline uint32_t mbedtls_ssl_get_input_buflen( const mbedtls_ssl_context *ctx )
+{
+#if defined (MBEDTLS_SSL_DTLS_CONNECTION_ID)
+    return (uint32_t) mbedtls_ssl_get_max_frag_len( ctx )
+               + MBEDTLS_SSL_HEADER_LEN + MBEDTLS_SSL_PAYLOAD_OVERHEAD
+               + MBEDTLS_SSL_CID_IN_LEN_MAX;
+#else
+    return (uint32_t) mbedtls_ssl_get_max_frag_len( ctx )
+               + MBEDTLS_SSL_HEADER_LEN + MBEDTLS_SSL_PAYLOAD_OVERHEAD;
+#endif
+}
+#endif
+
 #ifdef MBEDTLS_ZLIB_SUPPORT
 /* Compression buffer holds both IN and OUT buffers, so should be size of the larger */
 #define MBEDTLS_SSL_COMPRESS_BUFFER_LEN (                               \
diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 60ffa61..db2cdb6 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -245,6 +245,29 @@
     return( 0 );
 }
 
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+static int resize_buffer( unsigned char **buffer, size_t len_new, size_t *len_old )
+{
+    unsigned char* resized_buffer = mbedtls_calloc( 1, len_new );
+    if( resized_buffer == NULL )
+        return -1;
+
+    /* We want to copy len_new bytes when downsizing the buffer, and
+     * len_old bytes when upsizing, so we choose the smaller of two sizes,
+     * to fit one buffer into another. Size checks, ensuring that no data is
+     * lost, are done outside of this function. */
+    memcpy( resized_buffer, *buffer,
+            ( len_new < *len_old ) ? len_new : *len_old );
+    mbedtls_platform_zeroize( *buffer, *len_old );
+    mbedtls_free( *buffer );
+
+    *buffer = resized_buffer;
+    *len_old = len_new;
+
+    return 0;
+}
+#endif /* MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH */
+
 /*
  * Key material generation
  */
@@ -3643,6 +3666,43 @@
     {
         ssl->handshake = mbedtls_calloc( 1, sizeof(mbedtls_ssl_handshake_params) );
     }
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+    /* If the buffers are too small - reallocate */
+    {
+        int modified = 0;
+        if( ssl->in_buf_len < MBEDTLS_SSL_IN_BUFFER_LEN )
+        {
+            if( resize_buffer( &ssl->in_buf, MBEDTLS_SSL_IN_BUFFER_LEN,
+                               &ssl->in_buf_len ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 1, ( "input buffer resizing failed - out of memory" ) );
+            }
+            else
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 2, ( "Reallocating in_buf to %d", MBEDTLS_SSL_IN_BUFFER_LEN ) );
+                modified = 1;
+            }
+        }
+        if( ssl->out_buf_len < MBEDTLS_SSL_OUT_BUFFER_LEN )
+        {
+            if( resize_buffer( &ssl->out_buf, MBEDTLS_SSL_OUT_BUFFER_LEN,
+                               &ssl->out_buf_len ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 1, ( "output buffer resizing failed - out of memory" ) );
+            }
+            else
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 2, ( "Reallocating out_buf to %d", MBEDTLS_SSL_OUT_BUFFER_LEN ) );
+                modified = 1;
+            }
+        }
+        if( modified )
+        {
+            /* Update pointers here to avoid doing it twice. */
+            mbedtls_ssl_reset_in_out_pointers( ssl );
+        }
+    }
+#endif
 
     /* All pointers should exist and can be directly freed without issue */
     if( ssl->handshake == NULL ||
@@ -5818,6 +5878,60 @@
 
     mbedtls_platform_zeroize( handshake,
                               sizeof( mbedtls_ssl_handshake_params ) );
+
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+    /* If the buffers are too big - reallocate. Because of the way Mbed TLS
+     * processes datagrams and the fact that a datagram is allowed to have
+     * several records in it, it is possible that the I/O buffers are not
+     * empty at this stage */
+    {
+        int modified = 0;
+        uint32_t buf_len = mbedtls_ssl_get_input_buflen( ssl );
+        size_t written_in = 0;
+        size_t written_out = 0;
+        if( ssl->in_buf != NULL &&
+            ssl->in_buf_len > buf_len &&
+            ssl->in_left < buf_len )
+        {
+            written_in = ssl->in_msg - ssl->in_buf;
+            if( resize_buffer( &ssl->in_buf, buf_len, &ssl->in_buf_len ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 1, ( "input buffer resizing failed - out of memory" ) );
+            }
+            else
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 2, ( "Reallocating in_buf to %d", buf_len ) );
+                modified = 1;
+            }
+        }
+
+        buf_len = mbedtls_ssl_get_output_buflen( ssl );
+        if( ssl->out_buf != NULL &&
+            ssl->out_buf_len > mbedtls_ssl_get_output_buflen( ssl ) &&
+            ssl->out_left < buf_len )
+        {
+            written_out = ssl->out_msg - ssl->out_buf;
+            if( resize_buffer( &ssl->out_buf, buf_len, &ssl->out_buf_len ) != 0 )
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 1, ( "output buffer resizing failed - out of memory" ) );
+            }
+            else
+            {
+                MBEDTLS_SSL_DEBUG_MSG( 2, ( "Reallocating out_buf to %d", buf_len ) );
+                modified = 1;
+            }
+        }
+        if( modified )
+        {
+            /* Update pointers here to avoid doing it twice. */
+            mbedtls_ssl_reset_in_out_pointers( ssl );
+            /* Fields below might not be properly updated with record
+             * splitting, so they are manually updated here. */
+            ssl->out_msg = ssl->out_buf + written_out;
+            ssl->in_msg = ssl->in_buf + written_in;
+        }
+    }
+#endif
 }
 
 void mbedtls_ssl_session_free( mbedtls_ssl_session *session )
@@ -6499,12 +6613,14 @@
     {
         mbedtls_platform_zeroize( ssl->out_buf, out_buf_len );
         mbedtls_free( ssl->out_buf );
+        ssl->out_buf = NULL;
     }
 
     if( ssl->in_buf != NULL )
     {
         mbedtls_platform_zeroize( ssl->in_buf, in_buf_len );
         mbedtls_free( ssl->in_buf );
+        ssl->in_buf = NULL;
     }
 
 #if defined(MBEDTLS_ZLIB_SUPPORT)
diff --git a/tests/suites/test_suite_ssl.data b/tests/suites/test_suite_ssl.data
index 39a9d54..1a24105 100644
--- a/tests/suites/test_suite_ssl.data
+++ b/tests/suites/test_suite_ssl.data
@@ -368,6 +368,54 @@
 DTLS renegotiation: legacy break handshake
 renegotiation:MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE
 
+DTLS serialization with MFL=512
+resize_buffers_serialize_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_512
+
+DTLS serialization with MFL=1024
+resize_buffers_serialize_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_1024
+
+DTLS serialization with MFL=2048
+resize_buffers_serialize_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_2048
+
+DTLS serialization with MFL=4096
+resize_buffers_serialize_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_4096
+
+DTLS no legacy renegotiation with MFL=512
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_512:MBEDTLS_SSL_LEGACY_NO_RENEGOTIATION
+
+DTLS no legacy renegotiation with MFL=1024
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_1024:MBEDTLS_SSL_LEGACY_NO_RENEGOTIATION
+
+DTLS no legacy renegotiation with MFL=2048
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_2048:MBEDTLS_SSL_LEGACY_NO_RENEGOTIATION
+
+DTLS no legacy renegotiation with MFL=4096
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_4096:MBEDTLS_SSL_LEGACY_NO_RENEGOTIATION
+
+DTLS legacy allow renegotiation with MFL=512
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_512:MBEDTLS_SSL_LEGACY_ALLOW_RENEGOTIATION
+
+DTLS legacy allow renegotiation with MFL=1024
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_1024:MBEDTLS_SSL_LEGACY_ALLOW_RENEGOTIATION
+
+DTLS legacy allow renegotiation with MFL=2048
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_2048:MBEDTLS_SSL_LEGACY_ALLOW_RENEGOTIATION
+
+DTLS legacy allow renegotiation with MFL=4096
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_4096:MBEDTLS_SSL_LEGACY_ALLOW_RENEGOTIATION
+
+DTLS legacy break handshake renegotiation with MFL=512
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_512:MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE
+
+DTLS legacy break handshake renegotiation with MFL=1024
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_1024:MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE
+
+DTLS legacy break handshake renegotiation with MFL=2048
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_2048:MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE
+
+DTLS legacy break handshake renegotiation with MFL=4096
+resize_buffers_renegotiate_mfl:MBEDTLS_SSL_MAX_FRAG_LEN_4096:MBEDTLS_SSL_LEGACY_BREAK_HANDSHAKE
+
 SSL DTLS replay: initial state, seqnum 0
 ssl_dtls_replay:"":"000000000000":0
 
diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function
index 5485d9e..9b61a50 100644
--- a/tests/suites/test_suite_ssl.function
+++ b/tests/suites/test_suite_ssl.function
@@ -55,6 +55,7 @@
     void *cli_log_obj;
     void (*srv_log_fun)(void *, int, const char *, int, const char *);
     void (*cli_log_fun)(void *, int, const char *, int, const char *);
+    int resize_buffers;
 } handshake_test_options;
 
 void init_handshake_options( handshake_test_options *opts )
@@ -77,6 +78,7 @@
   opts->srv_log_obj = NULL;
   opts->srv_log_fun = NULL;
   opts->cli_log_fun = NULL;
+  opts->resize_buffers = 1;
 }
 /*
  * Buffer structure for custom I/O callbacks.
@@ -1767,6 +1769,17 @@
                                               &(server.socket),
                                               BUFFSIZE ) == 0 );
 
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+    if( options->resize_buffers != 0 )
+    {
+        /* Ensure that the buffer sizes are appropriate before resizes */
+        TEST_ASSERT( client.ssl.out_buf_len == MBEDTLS_SSL_OUT_BUFFER_LEN );
+        TEST_ASSERT( client.ssl.in_buf_len == MBEDTLS_SSL_IN_BUFFER_LEN );
+        TEST_ASSERT( server.ssl.out_buf_len == MBEDTLS_SSL_OUT_BUFFER_LEN );
+        TEST_ASSERT( server.ssl.in_buf_len == MBEDTLS_SSL_IN_BUFFER_LEN );
+    }
+#endif
+
     TEST_ASSERT( mbedtls_move_handshake_to_state( &(client.ssl),
                                                   &(server.ssl),
                                                   MBEDTLS_SSL_HANDSHAKE_OVER )
@@ -1774,6 +1787,31 @@
     TEST_ASSERT( client.ssl.state == MBEDTLS_SSL_HANDSHAKE_OVER );
     TEST_ASSERT( server.ssl.state == MBEDTLS_SSL_HANDSHAKE_OVER );
 
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+    if( options->resize_buffers != 0 )
+    {
+        /* Note - the case below will have to updated, since due to a 1n-1
+         * split against BEAST the fragment count is different
+         * than expected when preparing the fragment counting code. */
+        if( options->version != MBEDTLS_SSL_MINOR_VERSION_0 &&
+            options->version != MBEDTLS_SSL_MINOR_VERSION_1 )
+        {
+            /* A server, when using DTLS, might delay a buffer resize to happen
+             * after it receives a message, so we force it. */
+            TEST_ASSERT( exchange_data( &(client.ssl), &(server.ssl) ) == 0 );
+
+            TEST_ASSERT( client.ssl.out_buf_len ==
+                         mbedtls_ssl_get_output_buflen( &client.ssl ) );
+            TEST_ASSERT( client.ssl.in_buf_len ==
+                         mbedtls_ssl_get_input_buflen( &client.ssl ) );
+            TEST_ASSERT( server.ssl.out_buf_len ==
+                         mbedtls_ssl_get_output_buflen( &server.ssl ) );
+            TEST_ASSERT( server.ssl.in_buf_len ==
+                         mbedtls_ssl_get_input_buflen( &server.ssl ) );
+        }
+    }
+#endif
+
     if( options->cli_msg_len != 0 || options->srv_msg_len != 0 )
     {
         /* Start data exchanging test */
@@ -1814,9 +1852,27 @@
                                   mbedtls_timing_set_delay,
                                   mbedtls_timing_get_delay );
 #endif
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+        if( options->resize_buffers != 0 )
+        {
+            /* Ensure that the buffer sizes are appropriate before resizes */
+            TEST_ASSERT( server.ssl.out_buf_len == MBEDTLS_SSL_OUT_BUFFER_LEN );
+            TEST_ASSERT( server.ssl.in_buf_len == MBEDTLS_SSL_IN_BUFFER_LEN );
+        }
+#endif
         TEST_ASSERT( mbedtls_ssl_context_load( &( server.ssl ), context_buf,
                                                context_buf_len ) == 0 );
 
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+        /* Validate buffer sizes after context deserialization */
+        if( options->resize_buffers != 0 )
+        {
+            TEST_ASSERT( server.ssl.out_buf_len ==
+                         mbedtls_ssl_get_output_buflen( &server.ssl ) );
+            TEST_ASSERT( server.ssl.in_buf_len ==
+                         mbedtls_ssl_get_input_buflen( &server.ssl ) );
+        }
+#endif
         /* Retest writing/reading */
         if( options->cli_msg_len != 0 || options->srv_msg_len != 0 )
         {
@@ -1830,6 +1886,7 @@
         }
     }
 #endif /* MBEDTLS_SSL_CONTEXT_SERIALIZATION */
+
 #if defined(MBEDTLS_SSL_RENEGOTIATION)
     if( options->renegotiate )
     {
@@ -1859,6 +1916,14 @@
          * function will return waiting error on the socket. All rest of
          * renegotiation should happen during data exchanging */
         ret = mbedtls_ssl_renegotiate( &(client.ssl) );
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+        if( options->resize_buffers != 0 )
+        {
+            /* Ensure that the buffer sizes are appropriate before resizes */
+            TEST_ASSERT( client.ssl.out_buf_len == MBEDTLS_SSL_OUT_BUFFER_LEN );
+            TEST_ASSERT( client.ssl.in_buf_len == MBEDTLS_SSL_IN_BUFFER_LEN );
+        }
+#endif
         TEST_ASSERT( ret == 0 ||
                      ret == MBEDTLS_ERR_SSL_WANT_READ ||
                      ret == MBEDTLS_ERR_SSL_WANT_WRITE );
@@ -1872,6 +1937,20 @@
                      MBEDTLS_SSL_RENEGOTIATION_DONE );
         TEST_ASSERT( client.ssl.renego_status ==
                      MBEDTLS_SSL_RENEGOTIATION_DONE );
+#if defined(MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH)
+        /* Validate buffer sizes after renegotiation */
+        if( options->resize_buffers != 0 )
+        {
+            TEST_ASSERT( client.ssl.out_buf_len ==
+                         mbedtls_ssl_get_output_buflen( &client.ssl ) );
+            TEST_ASSERT( client.ssl.in_buf_len ==
+                         mbedtls_ssl_get_input_buflen( &client.ssl ) );
+            TEST_ASSERT( server.ssl.out_buf_len ==
+                         mbedtls_ssl_get_output_buflen( &server.ssl ) );
+            TEST_ASSERT( server.ssl.in_buf_len ==
+                         mbedtls_ssl_get_input_buflen( &server.ssl ) );
+        }
+#endif /* MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH */
     }
 #endif /* MBEDTLS_SSL_RENEGOTIATION */
 
@@ -3797,3 +3876,43 @@
     goto exit;
 }
 /* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_X509_CRT_PARSE_C:!MBEDTLS_USE_PSA_CRYPTO:MBEDTLS_PKCS1_V15:MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH:MBEDTLS_SSL_PROTO_TLS1_2:MBEDTLS_RSA_C:MBEDTLS_ECP_DP_SECP384R1_ENABLED */
+void resize_buffers( int mfl, int renegotiation, int legacy_renegotiation,
+                     int serialize, int dtls )
+{
+    handshake_test_options options;
+    init_handshake_options( &options );
+
+    options.mfl = mfl;
+    options.renegotiate = renegotiation;
+    options.legacy_renegotiation = legacy_renegotiation;
+    options.serialize = serialize;
+    options.dtls = dtls;
+    options.resize_buffers = 1;
+
+    perform_handshake( &options );
+    /* The goto below is used to avoid an "unused label" warning.*/
+    goto exit;
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_X509_CRT_PARSE_C:!MBEDTLS_USE_PSA_CRYPTO:MBEDTLS_PKCS1_V15:MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH:MBEDTLS_SSL_CONTEXT_SERIALIZATION:MBEDTLS_SSL_PROTO_TLS1_2:MBEDTLS_RSA_C:MBEDTLS_ECP_DP_SECP384R1_ENABLED:MBEDTLS_SSL_PROTO_DTLS */
+void resize_buffers_serialize_mfl( int mfl )
+{
+    test_resize_buffers( mfl, 0, MBEDTLS_SSL_LEGACY_NO_RENEGOTIATION, 1, 1 );
+
+    /* The goto below is used to avoid an "unused label" warning.*/
+    goto exit;
+}
+/* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_X509_CRT_PARSE_C:!MBEDTLS_USE_PSA_CRYPTO:MBEDTLS_PKCS1_V15:MBEDTLS_SSL_VARIABLE_BUFFER_LENGTH:MBEDTLS_SSL_RENEGOTIATION:MBEDTLS_SSL_PROTO_TLS1_2:MBEDTLS_RSA_C:MBEDTLS_ECP_DP_SECP384R1_ENABLED */
+void resize_buffers_renegotiate_mfl( int mfl, int legacy_renegotiation )
+{
+    test_resize_buffers( mfl, 1, legacy_renegotiation, 0, 1 );
+
+    /* The goto below is used to avoid an "unused label" warning.*/
+    goto exit;
+}
+/* END_CASE */