Have mbedtls_mpi_core_exp_mod() take a temporary instead of allocating memory

Last PR needed for #6293

Signed-off-by: Tom Cosgrove <tom.cosgrove@arm.com>
diff --git a/library/bignum_core.c b/library/bignum_core.c
index 2ed907d..30fc661 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -596,6 +596,19 @@
     return( wsize );
 }
 
+size_t mbedtls_mpi_core_exp_mod_working_limbs( size_t AN_limbs, size_t E_limbs )
+{
+    const size_t wsize = exp_mod_get_window_size( E_limbs * biL );
+    const size_t welem = ( (size_t) 1 ) << wsize;
+
+    /* How big does each part of the working memory pool need to be? */
+    const size_t table_limbs   = welem * AN_limbs;
+    const size_t select_limbs  = AN_limbs;
+    const size_t temp_limbs    = 2 * AN_limbs + 1;
+
+    return( table_limbs + select_limbs + temp_limbs );
+}
+
 static void exp_mod_precompute_window( const mbedtls_mpi_uint *A,
                                        const mbedtls_mpi_uint *N,
                                        size_t AN_limbs,
@@ -636,35 +649,27 @@
  * (The difference is that the body in our loop processes a single bit instead
  * of a full window.)
  */
-int mbedtls_mpi_core_exp_mod( mbedtls_mpi_uint *X,
-                              const mbedtls_mpi_uint *A,
-                              const mbedtls_mpi_uint *N,
-                              size_t AN_limbs,
-                              const mbedtls_mpi_uint *E,
-                              size_t E_limbs,
-                              const mbedtls_mpi_uint *RR )
+void mbedtls_mpi_core_exp_mod( mbedtls_mpi_uint *X,
+                               const mbedtls_mpi_uint *A,
+                               const mbedtls_mpi_uint *N,
+                               size_t AN_limbs,
+                               const mbedtls_mpi_uint *E,
+                               size_t E_limbs,
+                               const mbedtls_mpi_uint *RR,
+                               mbedtls_mpi_uint *T )
 {
     const size_t wsize = exp_mod_get_window_size( E_limbs * biL );
     const size_t welem = ( (size_t) 1 ) << wsize;
 
-    /* Allocate memory pool and set pointers to parts of it */
-    const size_t table_limbs   = welem * AN_limbs;
-    const size_t temp_limbs    = 2 * AN_limbs + 1;
-    const size_t select_limbs  = AN_limbs;
-    const size_t total_limbs   = table_limbs + temp_limbs + select_limbs;
+    /* This is how we will use the temporary storage T, which must have space
+     * for table_limbs, select_limbs and (2 * AN_limbs + 1) for montmul. */
+    const size_t table_limbs  = welem * AN_limbs;
+    const size_t select_limbs = AN_limbs;
 
-    /* heap allocated memory pool */
-    mbedtls_mpi_uint *mempool =
-        mbedtls_calloc( total_limbs, sizeof(mbedtls_mpi_uint) );
-    if( mempool == NULL )
-    {
-        return( MBEDTLS_ERR_MPI_ALLOC_FAILED );
-    }
-
-    /* pointers to temporaries within memory pool */
-    mbedtls_mpi_uint *const Wtable  = mempool;
-    mbedtls_mpi_uint *const Wselect = Wtable    + table_limbs;
-    mbedtls_mpi_uint *const temp    = Wselect  + select_limbs;
+    /* Pointers to specific parts of the temporary working memory pool */
+    mbedtls_mpi_uint *const Wtable  = T;
+    mbedtls_mpi_uint *const Wselect = Wtable  +  table_limbs;
+    mbedtls_mpi_uint *const temp    = Wselect + select_limbs;
 
     /*
      * Window precomputation
@@ -731,10 +736,6 @@
         }
     }
     while( ! ( E_bit_index == 0 && E_limb_index == 0 ) );
-
-    mbedtls_platform_zeroize( mempool, total_limbs * sizeof(mbedtls_mpi_uint) );
-    mbedtls_free( mempool );
-    return( 0 );
 }
 
 /* END MERGE SLOT 1 */
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 3348564..add7fee 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -499,28 +499,50 @@
 /* BEGIN MERGE SLOT 1 */
 
 /**
- * \brief          Perform a modular exponentiation with secret exponent:
- *                 X = A^E mod N, where \p A is already in Montgomery form.
+ * \brief          Returns the number of limbs of working memory required for
+ *                 a call to `mbedtls_mpi_core_exp_mod()`.
  *
- * \param[out] X   The destination MPI, as a little endian array of length
- *                 \p AN_limbs.
- * \param[in] A    The base MPI, as a little endian array of length \p AN_limbs.
- *                 Must be in Montgomery form.
- * \param[in] N    The modulus, as a little endian array of length \p AN_limbs.
- * \param AN_limbs The number of limbs in \p X, \p A, \p N, \p RR.
- * \param[in] E    The exponent, as a little endian array of length \p E_limbs.
- * \param E_limbs  The number of limbs in \p E.
- * \param[in] RR   The precomputed residue of 2^{2*biL} modulo N, as a little
- *                 endian array of length \p AN_limbs.
+ * \param AN_limbs The number of limbs in the input `A` and the modulus `N`
+ *                 (they must be the same size) that will be given to
+ *                 `mbedtls_mpi_core_exp_mod()`.
+ * \param E_limbs  The number of limbs in the exponent `E` that will be given
+ *                 to `mbedtls_mpi_core_exp_mod()`.
  *
- * \return         \c 0 if successful.
- * \return         #MBEDTLS_ERR_MPI_ALLOC_FAILED if a memory allocation failed.
+ * \return         The number of limbs of working memory required by
+ *                 `mbedtls_mpi_core_exp_mod()`.
  */
-int mbedtls_mpi_core_exp_mod( mbedtls_mpi_uint *X,
-                              const mbedtls_mpi_uint *A,
-                              const mbedtls_mpi_uint *N, size_t AN_limbs,
-                              const mbedtls_mpi_uint *E, size_t E_limbs,
-                              const mbedtls_mpi_uint *RR );
+size_t mbedtls_mpi_core_exp_mod_working_limbs( size_t AN_limbs, size_t E_limbs );
+
+/**
+ * \brief            Perform a modular exponentiation with secret exponent:
+ *                   X = A^E mod N, where \p A is already in Montgomery form.
+ *
+ * \param[out] X     The destination MPI, as a little endian array of length
+ *                   \p AN_limbs.
+ * \param[in] A      The base MPI, as a little endian array of length \p AN_limbs.
+ *                   Must be in Montgomery form.
+ * \param[in] N      The modulus, as a little endian array of length \p AN_limbs.
+ * \param AN_limbs   The number of limbs in \p X, \p A, \p N, \p RR.
+ * \param[in] E      The exponent, as a little endian array of length \p E_limbs.
+ * \param E_limbs    The number of limbs in \p E.
+ * \param[in] RR     The precomputed residue of 2^{2*biL} modulo N, as a little
+ *                   endian array of length \p AN_limbs.
+ * \param[in,out] T  Temporary storage of at least the number of limbs returned
+ *                   by `mbedtls_mpi_core_exp_mod_working_limbs()`.
+ *                   Its initial content is unused and its final content is
+ *                   indeterminate.
+ *                   It must not alias or otherwise overlap any of the other
+ *                   parameters.
+ *                   It is up to the caller to zeroize \p T when it is no
+ *                   longer needed, and before freeing it if it was dynamically
+ *                   allocated.
+ */
+void mbedtls_mpi_core_exp_mod( mbedtls_mpi_uint *X,
+                               const mbedtls_mpi_uint *A,
+                               const mbedtls_mpi_uint *N, size_t AN_limbs,
+                               const mbedtls_mpi_uint *E, size_t E_limbs,
+                               const mbedtls_mpi_uint *RR,
+                               mbedtls_mpi_uint *T );
 
 /* END MERGE SLOT 1 */
 
diff --git a/tests/suites/test_suite_bignum_core.function b/tests/suites/test_suite_bignum_core.function
index 078239f..b64127a 100644
--- a/tests/suites/test_suite_bignum_core.function
+++ b/tests/suites/test_suite_bignum_core.function
@@ -1046,15 +1046,13 @@
                        char * input_E, char * input_X )
 {
     mbedtls_mpi_uint *A = NULL;
-    size_t A_limbs;
     mbedtls_mpi_uint *E = NULL;
-    size_t E_limbs;
     mbedtls_mpi_uint *N = NULL;
-    size_t N_limbs;
     mbedtls_mpi_uint *X = NULL;
-    size_t X_limbs;
+    size_t A_limbs, E_limbs, N_limbs, X_limbs;
     const mbedtls_mpi_uint *R2 = NULL;
     mbedtls_mpi_uint *Y = NULL;
+    mbedtls_mpi_uint *T = NULL;
     /* Legacy MPIs for computing R2 */
     mbedtls_mpi N_mpi;
     mbedtls_mpi_init( &N_mpi );
@@ -1078,11 +1076,29 @@
     TEST_EQUAL( 0, mbedtls_mpi_grow( &R2_mpi, N_limbs ) );
     R2 = R2_mpi.p;
 
-    TEST_EQUAL( 0,
-                mbedtls_mpi_core_exp_mod( Y, A, N, N_limbs, E, E_limbs, R2 ) );
+    size_t working_limbs = mbedtls_mpi_core_exp_mod_working_limbs( N_limbs,
+                                                                   E_limbs );
+
+    /* No point exactly duplicating the code in mbedtls_mpi_core_exp_mod_working_limbs()
+     * to see if the output is correct, but we can check that it's in a
+     * reasonable range.  The current calculation works out as
+     * `1 + N_limbs * (welem + 3)`, where welem is the number of elements in
+     * the window (1 << 1 up to 1 << 6).
+     */
+    size_t min_expected_working_limbs = 1 + N_limbs * 4;
+    size_t max_expected_working_limbs = 1 + N_limbs * 67;
+
+    TEST_LE_U( min_expected_working_limbs, working_limbs );
+    TEST_LE_U( working_limbs, max_expected_working_limbs );
+
+    ASSERT_ALLOC( T, working_limbs );
+
+    mbedtls_mpi_core_exp_mod( Y, A, N, N_limbs, E, E_limbs, R2, T );
+
     TEST_EQUAL( 0, memcmp( X, Y, N_limbs * sizeof( mbedtls_mpi_uint ) ) );
 
 exit:
+    mbedtls_free( T );
     mbedtls_free( A );
     mbedtls_free( E );
     mbedtls_free( N );