Merge pull request #6731 from tom-cosgrove-arm/issue-6293-mod_exp

Require input to mbedtls_mpi_core_exp_mod() to already be in Montgomery form (the result is now also returned in Montgomery form)
diff --git a/library/bignum_core.c b/library/bignum_core.c
index 6635351..2ed907d 100644
--- a/library/bignum_core.c
+++ b/library/bignum_core.c
@@ -610,9 +610,9 @@
     Wtable[0] = 1;
     mbedtls_mpi_core_montmul( Wtable, Wtable, RR, AN_limbs, N, AN_limbs, mm, temp );
 
-    /* W[1] = A * R^2 * R^-1 mod N = A * R mod N */
+    /* W[1] = A (already in Montgomery presentation) */
     mbedtls_mpi_uint *W1 = Wtable + AN_limbs;
-    mbedtls_mpi_core_montmul( W1, A, RR, AN_limbs, N, AN_limbs, mm, temp );
+    memcpy( W1, A, AN_limbs * ciL );
 
     /* W[i+1] = W[i] * W[1], i >= 2 */
     mbedtls_mpi_uint *Wprev = W1;
@@ -626,6 +626,8 @@
 
 /* Exponentiation: X := A^E mod N.
  *
+ * A must already be in Montgomery form; X is also returned in Montgomery form.
+ *
  * As in other bignum functions, assume that AN_limbs and E_limbs are nonzero.
  *
  * RR must contain 2^{2*biL} mod N.
@@ -730,10 +732,6 @@
     }
     while( ! ( E_bit_index == 0 && E_limb_index == 0 ) );
 
-    /* Convert X back to normal presentation */
-    const mbedtls_mpi_uint one = 1;
-    mbedtls_mpi_core_montmul( X, X, &one, 1, N, AN_limbs, mm, temp );
-
     mbedtls_platform_zeroize( mempool, total_limbs * sizeof(mbedtls_mpi_uint) );
     mbedtls_free( mempool );
     return( 0 );
diff --git a/library/bignum_core.h b/library/bignum_core.h
index 24559c6..3348564 100644
--- a/library/bignum_core.h
+++ b/library/bignum_core.h
@@ -500,11 +500,12 @@
 
 /**
  * \brief          Perform a modular exponentiation with secret exponent:
- *                 X = A^E mod N
+ *                 X = A^E mod N, with both \p A and \p X in Montgomery form.
  *
  * \param[out] X   The destination MPI, as a little endian array of length
  *                 \p AN_limbs.
  * \param[in] A    The base MPI, as a little endian array of length \p AN_limbs.
+ *                 Must be in Montgomery form.
  * \param[in] N    The modulus, as a little endian array of length \p AN_limbs.
  * \param AN_limbs The number of limbs in \p X, \p A, \p N, \p RR.
  * \param[in] E    The exponent, as a little endian array of length \p E_limbs.
diff --git a/scripts/mbedtls_dev/bignum_common.py b/scripts/mbedtls_dev/bignum_common.py
index 67ea78d..3ff8b2f 100644
--- a/scripts/mbedtls_dev/bignum_common.py
+++ b/scripts/mbedtls_dev/bignum_common.py
@@ -251,6 +251,12 @@
         # provides earlier/more robust input validation.
         self.int_n = hex_to_int(val_n)
 
+    def to_montgomery(self, val: int) -> int:
+        return (val * self.r) % self.int_n
+
+    def from_montgomery(self, val: int) -> int:
+        return (val * self.r_inv) % self.int_n
+
     @property
     def boundary(self) -> int:
         return self.int_n
diff --git a/scripts/mbedtls_dev/bignum_core.py b/scripts/mbedtls_dev/bignum_core.py
index 2960d24..118a659 100644
--- a/scripts/mbedtls_dev/bignum_core.py
+++ b/scripts/mbedtls_dev/bignum_core.py
@@ -759,12 +759,23 @@
     """Test cases for bignum core exponentiation."""
     symbol = "^"
     test_function = "mpi_core_exp_mod"
-    test_name = "Core modular exponentiation"
+    test_name = "Core modular exponentiation (Montgomery form only)"
     input_style = "fixed"
 
+    def arguments(self) -> List[str]:
+        # Input 'a' has to be given in Montgomery form
+        mont_a = self.to_montgomery(self.int_a)
+        arg_mont_a = self.format_arg('{:x}'.format(mont_a))
+        return [bignum_common.quote_str(n) for n in [self.arg_n,
+                                                     arg_mont_a,
+                                                     self.arg_b]
+               ] + self.result()
+
     def result(self) -> List[str]:
+        # Result has to be given in Montgomery form too
         result = pow(self.int_a, self.int_b, self.int_n)
-        return [self.format_result(result)]
+        mont_result = self.to_montgomery(result)
+        return [self.format_result(mont_result)]
 
     @property
     def is_valid(self) -> bool:
diff --git a/scripts/mbedtls_dev/bignum_mod_raw.py b/scripts/mbedtls_dev/bignum_mod_raw.py
index 0bbad5d..d05479a 100644
--- a/scripts/mbedtls_dev/bignum_mod_raw.py
+++ b/scripts/mbedtls_dev/bignum_mod_raw.py
@@ -92,10 +92,9 @@
     arity = 1
 
     def result(self) -> List[str]:
-        result = (self.int_a * self.r) % self.int_n
+        result = self.to_montgomery(self.int_a)
         return [self.format_result(result)]
 
-
 class BignumModRawConvertFromMont(bignum_common.ModOperationCommon,
                                   BignumModRawTarget):
     """ Test cases for mpi_mod_raw_from_mont_rep(). """
@@ -106,7 +105,7 @@
     arity = 1
 
     def result(self) -> List[str]:
-        result = (self.int_a * self.r_inv) % self.int_n
+        result = self.from_montgomery(self.int_a)
         return [self.format_result(result)]