Merge remote-tracking branch 'origin/development' into merge-dev
diff --git a/ChangeLog.d/fix-in-cid-buffer-size.txt b/ChangeLog.d/fix-in-cid-buffer-size.txt
new file mode 100644
index 0000000..8a6c850
--- /dev/null
+++ b/ChangeLog.d/fix-in-cid-buffer-size.txt
@@ -0,0 +1,4 @@
+Security
+    * Fix potential heap buffer overread and overwrite in DTLS if
+      MBEDTLS_SSL_DTLS_CONNECTION_ID is enabled and
+      MBEDTLS_SSL_CID_IN_LEN_MAX > 2 * MBEDTLS_SSL_CID_OUT_LEN_MAX.
diff --git a/ChangeLog.d/rsa-fix-priviliged-side-channel.txt b/ChangeLog.d/rsa-fix-priviliged-side-channel.txt
new file mode 100644
index 0000000..bafe18d
--- /dev/null
+++ b/ChangeLog.d/rsa-fix-priviliged-side-channel.txt
@@ -0,0 +1,10 @@
+Security
+   * An adversary with access to precise enough information about memory
+     accesses (typically, an untrusted operating system attacking a secure
+     enclave) could recover an RSA private key after observing the victim
+     performing a single private-key operation if the window size used for the
+     exponentiation was 3 or smaller. Found and reported by Zili KOU,
+     Wenjian HE, Sharad Sinha, and Wei ZHANG. See "Cache Side-channel Attacks
+     and Defenses of the Sliding Window Algorithm in TEEs" - Design, Automation
+     and Test in Europe 2023.
+
diff --git a/library/bignum.c b/library/bignum.c
index a68957a..65708c9 100644
--- a/library/bignum.c
+++ b/library/bignum.c
@@ -1590,11 +1590,11 @@
                          mbedtls_mpi *prec_RR )
 {
     int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
-    size_t wbits, wsize, one = 1;
+    size_t window_bitsize;
     size_t i, j, nblimbs;
     size_t bufsize, nbits;
     mbedtls_mpi_uint ei, mm, state;
-    mbedtls_mpi RR, T, W[ 1 << MBEDTLS_MPI_WINDOW_SIZE ], WW, Apos;
+    mbedtls_mpi RR, T, W[ (size_t) 1 << MBEDTLS_MPI_WINDOW_SIZE ], WW, Apos;
     int neg;
 
     MPI_VALIDATE_RET( X != NULL );
@@ -1623,21 +1623,59 @@
 
     i = mbedtls_mpi_bitlen( E );
 
-    wsize = ( i > 671 ) ? 6 : ( i > 239 ) ? 5 :
+    window_bitsize = ( i > 671 ) ? 6 : ( i > 239 ) ? 5 :
             ( i >  79 ) ? 4 : ( i >  23 ) ? 3 : 1;
 
 #if( MBEDTLS_MPI_WINDOW_SIZE < 6 )
-    if( wsize > MBEDTLS_MPI_WINDOW_SIZE )
-        wsize = MBEDTLS_MPI_WINDOW_SIZE;
+    if( window_bitsize > MBEDTLS_MPI_WINDOW_SIZE )
+        window_bitsize = MBEDTLS_MPI_WINDOW_SIZE;
 #endif
 
+    const size_t w_table_used_size = (size_t) 1 << window_bitsize;
+
+    /*
+     * This function is not constant-trace: its memory accesses depend on the
+     * exponent value. To defend against timing attacks, callers (such as RSA
+     * and DHM) should use exponent blinding. However this is not enough if the
+     * adversary can find the exponent in a single trace, so this function
+     * takes extra precautions against adversaries who can observe memory
+     * access patterns.
+     *
+     * This function performs a series of multiplications by table elements and
+     * squarings, and we want to prevent the adversary from finding out which
+     * table element was used, and from distinguishing between multiplications
+     * and squarings. Firstly, when multiplying by an element of the window
+     * W[i], we do a constant-trace table lookup to obfuscate i. This leaves
+     * squarings as having different memory access patterns from other
+     * multiplications. So secondly, we put the accumulator X in the table as
+     * well, and also do a constant-trace table lookup to multiply by X.
+     *
+     * This way, all multiplications take the form of a lookup-and-multiply.
+     * The number of lookup-and-multiply operations inside each iteration of
+     * the main loop still depends on the bits of the exponent, but since the
+     * other operations in the loop don't have an easily recognizable memory
+     * trace, an adversary is unlikely to be able to observe the exact
+     * patterns.
+     *
+     * An adversary may still be able to recover the exponent if they can
+     * observe both memory accesses and branches. However, branch prediction
+     * exploitation typically requires many traces of execution over the same
+     * data, which is defeated by randomized blinding.
+     *
+     * To achieve this, we make a copy of X and we use the table entry in each
+     * calculation from this point on.
+     */
+    const size_t x_index = 0;
+    mbedtls_mpi_init( &W[x_index] );
+    mbedtls_mpi_copy( &W[x_index], X );
+
     j = N->n + 1;
     /* All W[i] and X must have at least N->n limbs for the mpi_montmul()
      * and mpi_montred() calls later. Here we ensure that W[1] and X are
      * large enough, and later we'll grow other W[i] to the same length.
      * They must not be shrunk midway through this function!
      */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( X, j ) );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[x_index], j ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[1],  j ) );
     MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &T, j * 2 ) );
 
@@ -1686,28 +1724,36 @@
     mpi_montmul( &W[1], &RR, N, mm, &T );
 
     /*
-     * X = R^2 * R^-1 mod N = R mod N
+     * W[x_index] = R^2 * R^-1 mod N = R mod N
      */
-    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( X, &RR ) );
-    mpi_montred( X, N, mm, &T );
+    MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[x_index], &RR ) );
+    mpi_montred( &W[x_index], N, mm, &T );
 
-    if( wsize > 1 )
+
+    if( window_bitsize > 1 )
     {
         /*
-         * W[1 << (wsize - 1)] = W[1] ^ (wsize - 1)
+         * W[i] = W[1] ^ i
+         *
+         * The first bit of the sliding window is always 1 and therefore we
+         * only need to store the second half of the table.
+         *
+         * (There are two special elements in the table: W[0] for the
+         * accumulator/result and W[1] for A in Montgomery form. Both of these
+         * are already set at this point.)
          */
-        j =  one << ( wsize - 1 );
+        j = w_table_used_size / 2;
 
         MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[j], N->n + 1 ) );
         MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[j], &W[1]    ) );
 
-        for( i = 0; i < wsize - 1; i++ )
+        for( i = 0; i < window_bitsize - 1; i++ )
             mpi_montmul( &W[j], &W[j], N, mm, &T );
 
         /*
          * W[i] = W[i - 1] * W[1]
          */
-        for( i = j + 1; i < ( one << wsize ); i++ )
+        for( i = j + 1; i < w_table_used_size; i++ )
         {
             MBEDTLS_MPI_CHK( mbedtls_mpi_grow( &W[i], N->n + 1 ) );
             MBEDTLS_MPI_CHK( mbedtls_mpi_copy( &W[i], &W[i - 1] ) );
@@ -1719,7 +1765,7 @@
     nblimbs = E->n;
     bufsize = 0;
     nbits   = 0;
-    wbits   = 0;
+    size_t exponent_bits_in_window = 0;
     state   = 0;
 
     while( 1 )
@@ -1747,9 +1793,10 @@
         if( ei == 0 && state == 1 )
         {
             /*
-             * out of window, square X
+             * out of window, square W[x_index]
              */
-            mpi_montmul( X, X, N, mm, &T );
+            MBEDTLS_MPI_CHK( mpi_select( &WW, W, w_table_used_size, x_index ) );
+            mpi_montmul( &W[x_index], &WW, N, mm, &T );
             continue;
         }
 
@@ -1759,25 +1806,30 @@
         state = 2;
 
         nbits++;
-        wbits |= ( ei << ( wsize - nbits ) );
+        exponent_bits_in_window |= ( ei << ( window_bitsize - nbits ) );
 
-        if( nbits == wsize )
+        if( nbits == window_bitsize )
         {
             /*
-             * X = X^wsize R^-1 mod N
+             * W[x_index] = W[x_index]^(2^window_bitsize) R^-1 mod N
              */
-            for( i = 0; i < wsize; i++ )
-                mpi_montmul( X, X, N, mm, &T );
+            for( i = 0; i < window_bitsize; i++ )
+            {
+                MBEDTLS_MPI_CHK( mpi_select( &WW, W, w_table_used_size,
+                                             x_index ) );
+                mpi_montmul( &W[x_index], &WW, N, mm, &T );
+            }
 
             /*
-             * X = X * W[wbits] R^-1 mod N
+             * W[x_index] = W[x_index] * W[exponent_bits_in_window] R^-1 mod N
              */
-            MBEDTLS_MPI_CHK( mpi_select( &WW, W, (size_t) 1 << wsize, wbits ) );
-            mpi_montmul( X, &WW, N, mm, &T );
+            MBEDTLS_MPI_CHK( mpi_select( &WW, W, w_table_used_size,
+                                         exponent_bits_in_window ) );
+            mpi_montmul( &W[x_index], &WW, N, mm, &T );
 
             state--;
             nbits = 0;
-            wbits = 0;
+            exponent_bits_in_window = 0;
         }
     }
 
@@ -1786,31 +1838,45 @@
      */
     for( i = 0; i < nbits; i++ )
     {
-        mpi_montmul( X, X, N, mm, &T );
+        MBEDTLS_MPI_CHK( mpi_select( &WW, W, w_table_used_size, x_index ) );
+        mpi_montmul( &W[x_index], &WW, N, mm, &T );
 
-        wbits <<= 1;
+        exponent_bits_in_window <<= 1;
 
-        if( ( wbits & ( one << wsize ) ) != 0 )
-            mpi_montmul( X, &W[1], N, mm, &T );
+        if( ( exponent_bits_in_window & ( (size_t) 1 << window_bitsize ) ) != 0 )
+        {
+            MBEDTLS_MPI_CHK( mpi_select( &WW, W, w_table_used_size, 1 ) );
+            mpi_montmul( &W[x_index], &WW, N, mm, &T );
+        }
     }
 
     /*
-     * X = A^E * R * R^-1 mod N = A^E mod N
+     * W[x_index] = A^E * R * R^-1 mod N = A^E mod N
      */
-    mpi_montred( X, N, mm, &T );
+    mpi_montred( &W[x_index], N, mm, &T );
 
     if( neg && E->n != 0 && ( E->p[0] & 1 ) != 0 )
     {
-        X->s = -1;
-        MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi( X, N, X ) );
+        W[x_index].s = -1;
+        MBEDTLS_MPI_CHK( mbedtls_mpi_add_mpi( &W[x_index], N, &W[x_index] ) );
     }
 
+    /*
+     * Load the result in the output variable.
+     */
+    mbedtls_mpi_copy( X, &W[x_index] );
+
 cleanup:
 
-    for( i = ( one << ( wsize - 1 ) ); i < ( one << wsize ); i++ )
+    /* The first bit of the sliding window is always 1 and therefore the first
+     * half of the table was unused. */
+    for( i = w_table_used_size/2; i < w_table_used_size; i++ )
         mbedtls_mpi_free( &W[i] );
 
-    mbedtls_mpi_free( &W[1] ); mbedtls_mpi_free( &T ); mbedtls_mpi_free( &Apos );
+    mbedtls_mpi_free( &W[x_index] );
+    mbedtls_mpi_free( &W[1] );
+    mbedtls_mpi_free( &T );
+    mbedtls_mpi_free( &Apos );
     mbedtls_mpi_free( &WW );
 
     if( prec_RR == NULL || prec_RR->p == NULL )
diff --git a/library/ssl_misc.h b/library/ssl_misc.h
index 53d50f2..8388b63 100644
--- a/library/ssl_misc.h
+++ b/library/ssl_misc.h
@@ -1135,7 +1135,7 @@
 #if defined(MBEDTLS_SSL_DTLS_CONNECTION_ID)
     uint8_t in_cid_len;
     uint8_t out_cid_len;
-    unsigned char in_cid [ MBEDTLS_SSL_CID_OUT_LEN_MAX ];
+    unsigned char in_cid [ MBEDTLS_SSL_CID_IN_LEN_MAX ];
     unsigned char out_cid[ MBEDTLS_SSL_CID_OUT_LEN_MAX ];
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */