Merge pull request #6547 from yanesca/extract_mod_exp_from_prototype

Bignum: Extract mod exp from prototype
diff --git a/library/bignum_core.c b/library/bignum_core.c
index 41d3239..6635351 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -540,6 +540,7 @@
     return( ret );
 }
 
+MBEDTLS_STATIC_TESTABLE
 void mbedtls_mpi_core_ct_uint_table_lookup( mbedtls_mpi_uint *dest,
                                             const mbedtls_mpi_uint *table,
                                             size_t limbs,
@@ -582,6 +583,162 @@
 
 /* BEGIN MERGE SLOT 1 */
 
+/* Choose how many exponent bits mbedtls_mpi_core_exp_mod() gathers into one
+ * window (one table lookup and one multiplication) for an exponent that is
+ * Ebits bits long. Longer exponents get wider windows, trading a larger
+ * precomputed table (2^window entries) for fewer multiplications.
+ * The result is capped at MBEDTLS_MPI_WINDOW_SIZE when that is below 6.
+ */
+static size_t exp_mod_get_window_size( size_t Ebits )
+{
+    size_t window_size = 1;
+
+    if( Ebits > 671 )
+    {
+        window_size = 6;
+    }
+    else if( Ebits > 239 )
+    {
+        window_size = 5;
+    }
+    else if( Ebits > 79 )
+    {
+        window_size = 4;
+    }
+
+#if( MBEDTLS_MPI_WINDOW_SIZE < 6 )
+    /* Honour a smaller configured limit. */
+    if( window_size > MBEDTLS_MPI_WINDOW_SIZE )
+    {
+        window_size = MBEDTLS_MPI_WINDOW_SIZE;
+    }
+#endif
+
+    return( window_size );
+}
+
+/* Fill the window table used by mbedtls_mpi_core_exp_mod():
+ *
+ *     Wtable[i] = A^i * R mod N    for 0 <= i < welem,
+ *
+ * i.e. A^i in Montgomery representation, where R = 2^{biL * AN_limbs}.
+ *
+ * - Wtable must have room for welem * AN_limbs limbs.
+ * - welem must be at least 2 (Wtable[1] is always written).
+ * - RR must be R^2 mod N.
+ * - mm is the Montgomery constant for N.
+ * - temp is scratch space passed through to mbedtls_mpi_core_montmul().
+ */
+static void exp_mod_precompute_window( const mbedtls_mpi_uint *A,
+                                       const mbedtls_mpi_uint *N,
+                                       size_t AN_limbs,
+                                       mbedtls_mpi_uint mm,
+                                       const mbedtls_mpi_uint *RR,
+                                       size_t welem,
+                                       mbedtls_mpi_uint *Wtable,
+                                       mbedtls_mpi_uint *temp )
+{
+    /* W[0] = 1 * R^2 * R^-1 mod N = R mod N, i.e. 1 in Montgomery
+     * representation */
+    memset( Wtable, 0, AN_limbs * ciL );
+    Wtable[0] = 1;
+    mbedtls_mpi_core_montmul( Wtable, Wtable, RR, AN_limbs, N, AN_limbs, mm, temp );
+
+    /* W[1] = A * R^2 * R^-1 mod N = A * R mod N */
+    mbedtls_mpi_uint *const W1 = Wtable + AN_limbs;
+    mbedtls_mpi_core_montmul( W1, A, RR, AN_limbs, N, AN_limbs, mm, temp );
+
+    /* W[i] = W[i-1] * W[1] * R^-1 = A^i * R mod N, for 2 <= i < welem */
+    for( size_t i = 2; i < welem; i++ )
+    {
+        mbedtls_mpi_core_montmul( Wtable + i * AN_limbs,
+                                  Wtable + ( i - 1 ) * AN_limbs, W1,
+                                  AN_limbs, N, AN_limbs, mm, temp );
+    }
+}
+
+/* Exponentiation: X := A^E mod N.
+ *
+ * As in other bignum functions, AN_limbs is assumed to be nonzero.
+ * If E_limbs is 0 the exponent is treated as zero and X is set to 1 mod N
+ * (no limb of E is read in that case).
+ *
+ * N must be odd, since the computation uses Montgomery multiplication.
+ * A is expected to be in canonical form (A < N).
+ *
+ * RR must contain R^2 mod N, where R = 2^{biL * AN_limbs}
+ * (i.e. 2^{2 * biL * AN_limbs} mod N).
+ *
+ * X may alias A: A is only read while building the window table, before X
+ * is first written. X must not overlap N or E, which are read afterwards.
+ *
+ * The algorithm is a variant of Left-to-right k-ary exponentiation: HAC 14.82
+ * (The difference is that the body in our loop processes a single bit instead
+ * of a full window.) The sequence of squarings, table lookups and
+ * multiplications depends only on E_limbs, not on the bits of E. Table
+ * entries are fetched with mbedtls_mpi_core_ct_uint_table_lookup().
+ */
+int mbedtls_mpi_core_exp_mod( mbedtls_mpi_uint *X,
+                              const mbedtls_mpi_uint *A,
+                              const mbedtls_mpi_uint *N,
+                              size_t AN_limbs,
+                              const mbedtls_mpi_uint *E,
+                              size_t E_limbs,
+                              const mbedtls_mpi_uint *RR )
+{
+    const size_t wsize = exp_mod_get_window_size( E_limbs * biL );
+    const size_t welem = ( (size_t) 1 ) << wsize;
+
+    /* Allocate memory pool and set pointers to parts of it */
+    const size_t table_limbs   = welem * AN_limbs;
+    /* Scratch space passed to mbedtls_mpi_core_montmul() */
+    const size_t temp_limbs    = 2 * AN_limbs + 1;
+    const size_t select_limbs  = AN_limbs;
+    const size_t total_limbs   = table_limbs + temp_limbs + select_limbs;
+
+    /* heap allocated memory pool */
+    mbedtls_mpi_uint *mempool =
+        mbedtls_calloc( total_limbs, sizeof(mbedtls_mpi_uint) );
+    if( mempool == NULL )
+    {
+        return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
+    }
+
+    /* pointers to temporaries within memory pool */
+    mbedtls_mpi_uint *const Wtable  = mempool;
+    mbedtls_mpi_uint *const Wselect = Wtable    + table_limbs;
+    mbedtls_mpi_uint *const temp    = Wselect  + select_limbs;
+
+    /*
+     * Window precomputation
+     */
+
+    const mbedtls_mpi_uint mm = mbedtls_mpi_core_montmul_init( N );
+
+    /* Set Wtable[i] = A^i (in Montgomery representation), 0 <= i < welem */
+    exp_mod_precompute_window( A, N, AN_limbs,
+                               mm, RR,
+                               welem, Wtable, temp );
+
+    /*
+     * Fixed window exponentiation
+     */
+
+    /* X = 1 (in Montgomery presentation) initially */
+    memcpy( X, Wtable, AN_limbs * ciL );
+
+    /* We'll process the bits of E from most significant
+     * (limb_index=E_limbs-1, E_bit_index=biL-1) to least significant
+     * (limb_index=0, E_bit_index=0). The starting position is one past the
+     * most significant bit; each iteration moves it before reading. */
+    size_t E_limb_index = E_limbs;
+    size_t E_bit_index = 0;
+    /* At any given time, window contains window_bits bits from E.
+     * window_bits can go up to wsize. */
+    size_t window_bits = 0;
+    mbedtls_mpi_uint window = 0;
+
+    /* Pre-tested loop: for E_limbs >= 1 the first test is always true, so
+     * this behaves like a do/while. For E_limbs == 0 the body is skipped,
+     * instead of decrementing E_limb_index past zero and reading E out of
+     * bounds. */
+    while( ! ( E_bit_index == 0 && E_limb_index == 0 ) )
+    {
+        /* Square */
+        mbedtls_mpi_core_montmul( X, X, X, AN_limbs, N, AN_limbs, mm, temp );
+
+        /* Move to the next bit of the exponent */
+        if( E_bit_index == 0 )
+        {
+            --E_limb_index;
+            E_bit_index = biL - 1;
+        }
+        else
+        {
+            --E_bit_index;
+        }
+        /* Insert next exponent bit into window */
+        ++window_bits;
+        window <<= 1;
+        window |= ( E[E_limb_index] >> E_bit_index ) & 1;
+
+        /* Clear window if it's full. Also clear the window at the end,
+         * when we've finished processing the exponent. */
+        if( window_bits == wsize ||
+            ( E_bit_index == 0 && E_limb_index == 0 ) )
+        {
+            /* Select Wtable[window] without leaking window through
+             * memory access patterns. */
+            mbedtls_mpi_core_ct_uint_table_lookup( Wselect, Wtable,
+                                                   AN_limbs, welem, window );
+            /* Multiply X by the selected element. */
+            mbedtls_mpi_core_montmul( X, X, Wselect, AN_limbs, N, AN_limbs, mm,
+                                      temp );
+            window = 0;
+            window_bits = 0;
+        }
+    }
+
+    /* Convert X back to normal presentation: X * 1 * R^-1 mod N */
+    const mbedtls_mpi_uint one = 1;
+    mbedtls_mpi_core_montmul( X, X, &one, 1, N, AN_limbs, mm, temp );
+
+    /* The pool holds values derived from the (secret) exponent; wipe it. */
+    mbedtls_platform_zeroize( mempool, total_limbs * sizeof(mbedtls_mpi_uint) );
+    mbedtls_free( mempool );
+    return( 0 );
+}
+
 /* END MERGE SLOT 1 */
 
 /* BEGIN MERGE SLOT 2 */
diff --git a/library/bignum_core.h b/library/bignum_core.h
index d48e705..24559c6 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -452,6 +452,7 @@
 int mbedtls_mpi_core_get_mont_r2_unsafe( mbedtls_mpi *X,
                                          const mbedtls_mpi *N );
 
+#if defined(MBEDTLS_TEST_HOOKS)
 /**
  * Copy an MPI from a table without leaking the index.
  *
@@ -469,6 +470,7 @@
                                             size_t limbs,
                                             size_t count,
                                             size_t index );
+#endif /* MBEDTLS_TEST_HOOKS */
 
 /**
  * \brief          Fill an integer with a number of random bytes.
@@ -496,6 +498,29 @@
 
 /* BEGIN MERGE SLOT 1 */
 
+/**
+ * \brief          Perform a modular exponentiation with secret exponent:
+ *                 X = A^E mod N
+ *
+ * \param[out] X   The destination MPI, as a little endian array of length
+ *                 \p AN_limbs. It may alias \p A, but must not overlap
+ *                 \p N or \p E.
+ * \param[in] A    The base MPI, as a little endian array of length \p AN_limbs.
+ *                 It is expected to be in canonical form (A < N).
+ * \param[in] N    The modulus, as a little endian array of length \p AN_limbs.
+ *                 It must be odd, since the computation uses Montgomery
+ *                 multiplication.
+ * \param AN_limbs The number of limbs in \p X, \p A, \p N, \p RR.
+ *                 Must be nonzero.
+ * \param[in] E    The exponent, as a little endian array of length \p E_limbs.
+ *                 It may be larger than \p N (for example with exponent
+ *                 blinding).
+ * \param E_limbs  The number of limbs in \p E. Must be nonzero.
+ * \param[in] RR   The precomputed residue of 2^{2*biL*AN_limbs} modulo N
+ *                 (that is, R^2 mod N with R = 2^{biL*AN_limbs}), as a little
+ *                 endian array of length \p AN_limbs. For example, it can be
+ *                 computed with mbedtls_mpi_core_get_mont_r2_unsafe() and
+ *                 then extended to \p AN_limbs limbs.
+ *
+ * \return         \c 0 if successful.
+ * \return         #MBEDTLS_ERR_MPI_ALLOC_FAILED if a memory allocation failed.
+ */
+int mbedtls_mpi_core_exp_mod( mbedtls_mpi_uint *X,
+                              const mbedtls_mpi_uint *A,
+                              const mbedtls_mpi_uint *N, size_t AN_limbs,
+                              const mbedtls_mpi_uint *E, size_t E_limbs,
+                              const mbedtls_mpi_uint *RR );
+
 /* END MERGE SLOT 1 */
 
 /* BEGIN MERGE SLOT 2 */
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index b8e2a31..2960d24 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -755,6 +755,23 @@
 
 # BEGIN MERGE SLOT 1
 
+class BignumCoreExpMod(BignumCoreTarget, bignum_common.ModOperationCommon):
+    """Test cases for bignum core exponentiation."""
+    symbol = "^"
+    test_function = "mpi_core_exp_mod"
+    test_name = "Core modular exponentiation"
+    input_style = "fixed"
+
+    def result(self) -> List[str]:
+        result = pow(self.int_a, self.int_b, self.int_n)
+        return [self.format_result(result)]
+
+    @property
+    def is_valid(self) -> bool:
+        # The base needs to be canonical, but the exponent can be larger than
+        # the modulus (see for example exponent blinding)
+        return bool(self.int_a < self.int_n)
+
 # END MERGE SLOT 1
 
 # BEGIN MERGE SLOT 2
diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function
index d5bb420..078239f 100644
--- a/tests/suites/test_suite_bignum_core.function
+++ b/tests/suites/test_suite_bignum_core.function
@@ -935,7 +935,7 @@
 }
 /* END_CASE */
 
-/* BEGIN_CASE */
+/* BEGIN_CASE depends_on:MBEDTLS_TEST_HOOKS */
 void mpi_core_ct_uint_table_lookup( int bitlen, int window_size )
 {
     size_t limbs = BITS_TO_LIMBS( bitlen );
@@ -1041,6 +1041,59 @@
 
 /* BEGIN MERGE SLOT 1 */
 
+/* BEGIN_CASE */
+void mpi_core_exp_mod( char * input_N, char * input_A,
+                       char * input_E, char * input_X )
+{
+    mbedtls_mpi_uint *A = NULL;
+    size_t A_limbs;
+    mbedtls_mpi_uint *E = NULL;
+    size_t E_limbs;
+    mbedtls_mpi_uint *N = NULL;
+    size_t N_limbs;
+    mbedtls_mpi_uint *X = NULL;
+    size_t X_limbs;
+    const mbedtls_mpi_uint *R2 = NULL;
+    mbedtls_mpi_uint *Y = NULL;
+    /* Legacy MPIs for computing R2 */
+    mbedtls_mpi N_mpi;
+    mbedtls_mpi_init( &N_mpi );
+    mbedtls_mpi R2_mpi;
+    mbedtls_mpi_init( &R2_mpi );
+
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &A, &A_limbs, input_A ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &E, &E_limbs, input_E ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &N, &N_limbs, input_N ) );
+    TEST_EQUAL( 0, mbedtls_test_read_mpi_core( &X, &X_limbs, input_X ) );
+    ASSERT_ALLOC( Y, N_limbs );
+
+    TEST_EQUAL( A_limbs, N_limbs );
+    TEST_EQUAL( X_limbs, N_limbs );
+
+    /* Compute R2 = 2^(2 * biL * N_limbs) mod N with the legacy API, and make
+     * sure its buffer has at least N_limbs limbs, since exp_mod reads
+     * exactly N_limbs limbs of RR. */
+    TEST_EQUAL( 0, mbedtls_mpi_grow( &N_mpi, N_limbs ) );
+    memcpy( N_mpi.p, N, N_limbs * sizeof( *N ) );
+    N_mpi.n = N_limbs;
+    TEST_EQUAL( 0,
+                mbedtls_mpi_core_get_mont_r2_unsafe( &R2_mpi, &N_mpi ) );
+    TEST_EQUAL( 0, mbedtls_mpi_grow( &R2_mpi, N_limbs ) );
+    R2 = R2_mpi.p;
+
+    /* Separate output buffer */
+    TEST_EQUAL( 0,
+                mbedtls_mpi_core_exp_mod( Y, A, N, N_limbs, E, E_limbs, R2 ) );
+    ASSERT_COMPARE( X, X_limbs * sizeof( mbedtls_mpi_uint ),
+                    Y, N_limbs * sizeof( mbedtls_mpi_uint ) );
+
+    /* Output aliased to the base. This overwrites A, so it must come last. */
+    TEST_EQUAL( 0,
+                mbedtls_mpi_core_exp_mod( A, A, N, N_limbs, E, E_limbs, R2 ) );
+    ASSERT_COMPARE( X, X_limbs * sizeof( mbedtls_mpi_uint ),
+                    A, A_limbs * sizeof( mbedtls_mpi_uint ) );
+
+exit:
+    mbedtls_free( A );
+    mbedtls_free( E );
+    mbedtls_free( N );
+    mbedtls_free( X );
+    mbedtls_free( Y );
+    mbedtls_mpi_free( &N_mpi );
+    mbedtls_mpi_free( &R2_mpi );
+    // R2 doesn't need to be freed as it is only aliasing R2_mpi
+}
+/* END_CASE */
+
 /* END MERGE SLOT 1 */
 
 /* BEGIN MERGE SLOT 2 */