Merge pull request #6704 from tom-cosgrove-arm/issue-6020-mod_sub

Implement mbedtls_mpi_mod_sub()
diff --git a/library/bignum_mod.c b/library/bignum_mod.c
index 7a5539d..7cf2fb2 100644
--- a/library/bignum_mod.c
+++ b/library/bignum_mod.c
@@ -179,7 +179,18 @@
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */
+int mbedtls_mpi_mod_sub( mbedtls_mpi_mod_residue *X,
+                         const mbedtls_mpi_mod_residue *A,
+                         const mbedtls_mpi_mod_residue *B,
+                         const mbedtls_mpi_mod_modulus *N )
+{
+    if( X->limbs != N->limbs || A->limbs != N->limbs || B->limbs != N->limbs )
+        return( MBEDTLS_ERR_MPI_BAD_INPUT_DATA );
 
+    mbedtls_mpi_mod_raw_sub( X->p, A->p, B->p, N );
+
+    return( 0 );
+}
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */
diff --git a/library/bignum_mod.h b/library/bignum_mod.h
index d92f21e..0a8f4d3 100644
--- a/library/bignum_mod.h
+++ b/library/bignum_mod.h
@@ -163,7 +163,35 @@
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */
-
+/**
+ * \brief Perform a fixed-size modular subtraction.
+ *
+ * Calculate `A - B modulo N`.
+ *
+ * \p A, \p B and \p X must all have the same number of limbs as \p N.
+ *
+ * \p X may be aliased to \p A or \p B, or even both, but may not overlap
+ * either otherwise.
+ *
+ * \note This function does not check that \p A or \p B are in canonical
+ *       form (that is, are < \p N) - that will have been done by
+ *       mbedtls_mpi_mod_residue_setup().
+ *
+ * \param[out] X    The address of the result MPI. Must be initialized.
+ *                  Must have the same number of limbs as the modulus \p N.
+ * \param[in]  A    The address of the first MPI.
+ * \param[in]  B    The address of the second MPI.
+ * \param[in]  N    The address of the modulus. Used to perform a modulo
+ *                  operation on the result of the subtraction.
+ *
+ * \return          \c 0 if successful.
+ * \return          #MBEDTLS_ERR_MPI_BAD_INPUT_DATA if the given MPIs do not
+ *                  have the correct number of limbs.
+ */
+int mbedtls_mpi_mod_sub( mbedtls_mpi_mod_residue *X,
+                         const mbedtls_mpi_mod_residue *A,
+                         const mbedtls_mpi_mod_residue *B,
+                         const mbedtls_mpi_mod_modulus *N );
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */
diff --git a/scripts/mbedtls_dev/bignum_mod.py b/scripts/mbedtls_dev/bignum_mod.py
index 81ece07..aa06fe8 100644
--- a/scripts/mbedtls_dev/bignum_mod.py
+++ b/scripts/mbedtls_dev/bignum_mod.py
@@ -34,6 +34,20 @@
 
 # BEGIN MERGE SLOT 3
 
+class BignumModSub(bignum_common.ModOperationCommon, BignumModTarget):
+    """Test cases for bignum mpi_mod_sub()."""
+    symbol = "-"
+    test_function = "mpi_mod_sub"
+    test_name = "mbedtls_mpi_mod_sub"
+    input_style = "fixed"
+    arity = 2
+
+    def result(self) -> List[str]:
+        result = (self.int_a - self.int_b) % self.int_n
+        # To make negative tests easier, append 0 for success to the
+        # generated cases
+        return [self.format_result(result), "0"]
+
 # END MERGE SLOT 3
 
 # BEGIN MERGE SLOT 4
diff --git a/tests/suites/test_suite_bignum_mod.data b/tests/suites/test_suite_bignum_mod.data
index 2ea4a58..501d9d7 100644
--- a/tests/suites/test_suite_bignum_mod.data
+++ b/tests/suites/test_suite_bignum_mod.data
@@ -17,6 +17,27 @@
 
 # BEGIN MERGE SLOT 3
 
+mpi_mod_sub base case for negative testing (N, a, b all >= 1 limb)
+mpi_mod_sub:"014320a022ccb75bdf470ddf25":"000000025a55a46e5da99c71c7":"00033b2e3c9fd0803ce8000f93":"013fe57440828b4a0008aa4159":0
+
+mpi_mod_sub with modulus too long/both inputs too short
+mpi_mod_sub:"0000000014320a022ccb75bdf470ddf25":"000000025a55a46e5da99c71c7":"00033b2e3c9fd0803ce8000f93":"00":MBEDTLS_ERR_MPI_BAD_INPUT_DATA
+
+mpi_mod_sub with first input too long
+mpi_mod_sub:"014320a022ccb75bdf470ddf25":"0000000000000025a55a46e5da99c71c7":"00033b2e3c9fd0803ce8000f93":"00":MBEDTLS_ERR_MPI_BAD_INPUT_DATA
+
+mpi_mod_sub with second input too long
+mpi_mod_sub:"014320a022ccb75bdf470ddf25":"000000025a55a46e5da99c71c7":"000000000033b2e3c9fd0803ce8000f93":"00":MBEDTLS_ERR_MPI_BAD_INPUT_DATA
+
+mpi_mod_sub with both inputs too long
+mpi_mod_sub:"014320a022ccb75bdf470ddf25":"0000000000000025a55a46e5da99c71c7":"000000000033b2e3c9fd0803ce8000f93":"00":MBEDTLS_ERR_MPI_BAD_INPUT_DATA
+
+mpi_mod_sub with first input too short
+mpi_mod_sub:"014320a022ccb75bdf470ddf25":"a99c71c7":"00033b2e3c9fd0803ce8000f93":"00":MBEDTLS_ERR_MPI_BAD_INPUT_DATA
+
+mpi_mod_sub with second input too short
+mpi_mod_sub:"014320a022ccb75bdf470ddf25":"000000025a55a46e5da99c71c7":"e8000f93":"00":MBEDTLS_ERR_MPI_BAD_INPUT_DATA
+
 # END MERGE SLOT 3
 
 # BEGIN MERGE SLOT 4
diff --git a/tests/suites/test_suite_bignum_mod.function b/tests/suites/test_suite_bignum_mod.function
index a941cb6..0d2e232 100644
--- a/tests/suites/test_suite_bignum_mod.function
+++ b/tests/suites/test_suite_bignum_mod.function
@@ -4,6 +4,47 @@
 #include "bignum_mod.h"
 #include "constant_time_internal.h"
 #include "test/constant_flow.h"
+
+#define TEST_COMPARE_MPI_RESIDUES( a, b ) \
+            ASSERT_COMPARE( (a).p, (a).limbs * sizeof(mbedtls_mpi_uint), \
+                            (b).p, (b).limbs * sizeof(mbedtls_mpi_uint) )
+
+static int test_read_modulus( mbedtls_mpi_mod_modulus *m,
+                              mbedtls_mpi_mod_rep_selector int_rep,
+                              char *input )
+{
+    mbedtls_mpi_uint *p = NULL;
+    size_t limbs;
+
+    int ret = mbedtls_test_read_mpi_core( &p, &limbs, input );
+    if( ret != 0 )
+        return( ret );
+
+    return( mbedtls_mpi_mod_modulus_setup( m, p, limbs, int_rep ) );
+}
+
+static int test_read_residue( mbedtls_mpi_mod_residue *r,
+                              const mbedtls_mpi_mod_modulus *m,
+                              char *input,
+                              int skip_limbs_and_value_checks )
+{
+    mbedtls_mpi_uint *p = NULL;
+    size_t limbs;
+
+    int ret = mbedtls_test_read_mpi_core( &p, &limbs, input );
+    if( ret != 0 )
+        return( ret );
+
+    if( skip_limbs_and_value_checks )
+    {
+        r->p = p;
+        r->limbs = limbs;
+        return( 0 );
+    }
+
+    /* mbedtls_mpi_mod_residue_setup() checks limbs, and that value < m */
+    return( mbedtls_mpi_mod_residue_setup( r, m, p, limbs ) );
+}
 /* END_HEADER */
 
 /* BEGIN_DEPENDENCIES
@@ -64,7 +105,104 @@
 /* END MERGE SLOT 2 */
 
 /* BEGIN MERGE SLOT 3 */
+/* BEGIN_CASE */
+void mpi_mod_sub( char * input_N,
+                  char * input_A, char * input_B,
+                  char * input_D, int oret )
+{
+    mbedtls_mpi_mod_residue a = { NULL, 0 };
+    mbedtls_mpi_mod_residue b = { NULL, 0 };
+    mbedtls_mpi_mod_residue d = { NULL, 0 };
+    mbedtls_mpi_mod_residue x = { NULL, 0 };
+    mbedtls_mpi_uint *X_raw = NULL;
 
+    mbedtls_mpi_mod_modulus m;
+    mbedtls_mpi_mod_modulus_init( &m );
+
+    TEST_EQUAL( 0,
+        test_read_modulus( &m, MBEDTLS_MPI_MOD_REP_MONTGOMERY, input_N ) );
+
+    /* test_read_residue() normally checks that inputs have the same number of
+     * limbs as the modulus. For negative testing we can ask it to skip this
+     * with a non-zero final parameter. */
+    TEST_EQUAL( 0, test_read_residue( &a, &m, input_A, oret != 0 ) );
+    TEST_EQUAL( 0, test_read_residue( &b, &m, input_B, oret != 0 ) );
+    TEST_EQUAL( 0, test_read_residue( &d, &m, input_D, oret != 0 ) );
+
+    size_t limbs = m.limbs;
+    size_t bytes = limbs * sizeof( *X_raw );
+
+    /* One spare limb for negative testing */
+    ASSERT_ALLOC( X_raw, limbs + 1 );
+
+    if( oret == 0 )
+    {
+        /* Sneak in a couple of negative tests on known-good data */
+
+        /* First, negative test with too many limbs in output */
+        x.p = X_raw;
+        x.limbs = limbs + 1;
+        TEST_EQUAL( MBEDTLS_ERR_MPI_BAD_INPUT_DATA,
+                    mbedtls_mpi_mod_sub( &x, &a, &b, &m ) );
+
+        /* Then negative test with too few limbs in output */
+        if( limbs > 1 )
+        {
+            x.p = X_raw;
+            x.limbs = limbs - 1;
+            TEST_EQUAL( MBEDTLS_ERR_MPI_BAD_INPUT_DATA,
+                        mbedtls_mpi_mod_sub( &x, &a, &b, &m ) );
+        }
+
+        /* Negative testing with too many/too few limbs in a and b is covered by
+         * manually-written test cases with oret != 0. */
+
+        /* Back to the normally-scheduled programme */
+    }
+
+    TEST_EQUAL( 0, mbedtls_mpi_mod_residue_setup( &x, &m, X_raw, limbs ) );
+
+    /* a - b => Correct result, or expected error */
+    TEST_EQUAL( oret, mbedtls_mpi_mod_sub( &x, &a, &b, &m ) );
+    if( oret != 0 )
+        goto exit;
+
+    TEST_COMPARE_MPI_RESIDUES( x, d );
+
+    /* a - b: alias x to a => Correct result */
+    memcpy( x.p, a.p, bytes );
+    TEST_EQUAL( 0, mbedtls_mpi_mod_sub( &x, &x, &b, &m ) );
+    TEST_COMPARE_MPI_RESIDUES( x, d );
+
+    /* a - b: alias x to b => Correct result */
+    memcpy( x.p, b.p, bytes );
+    TEST_EQUAL( 0, mbedtls_mpi_mod_sub( &x, &a, &x, &m ) );
+    TEST_COMPARE_MPI_RESIDUES( x, d );
+
+    if ( memcmp( a.p, b.p, bytes ) == 0 )
+    {
+        /* a == b: alias a and b */
+
+        /* a - a => Correct result */
+        TEST_EQUAL( 0, mbedtls_mpi_mod_sub( &x, &a, &a, &m ) );
+        TEST_COMPARE_MPI_RESIDUES( x, d );
+
+        /* a - a: x, a, b all aliased together => Correct result */
+        memcpy( x.p, a.p, bytes );
+        TEST_EQUAL( 0, mbedtls_mpi_mod_sub( &x, &x, &x, &m ) );
+        TEST_COMPARE_MPI_RESIDUES( x, d );
+    }
+
+exit:
+    mbedtls_free( (void *)m.p ); /* mbedtls_mpi_mod_modulus_free() sets m.p = NULL */
+    mbedtls_mpi_mod_modulus_free( &m );
+
+    mbedtls_free( a.p );
+    mbedtls_free( b.p );
+    mbedtls_free( d.p );
+    mbedtls_free( X_raw );
+}
+/* END_CASE */
 /* END MERGE SLOT 3 */
 
 /* BEGIN MERGE SLOT 4 */