Merge pull request #5155 from paul-elliott-arm/pcks12_fix

Fixes for pkcs12 with NULL and/or zero length password
diff --git a/ChangeLog.d/fix-pkcs12-null-password.txt b/ChangeLog.d/fix-pkcs12-null-password.txt
new file mode 100644
index 0000000..fae8195
--- /dev/null
+++ b/ChangeLog.d/fix-pkcs12-null-password.txt
@@ -0,0 +1,5 @@
+Bugfix
+   * Fix potential invalid pointer dereference and infinite loop bugs in
+     pkcs12 functions when the password is empty. Fix the documentation to
+     better describe the inputs to these functions and their possible values.
+     Fixes #5136.
diff --git a/include/mbedtls/pkcs12.h b/include/mbedtls/pkcs12.h
index cded903..1b87aea 100644
--- a/include/mbedtls/pkcs12.h
+++ b/include/mbedtls/pkcs12.h
@@ -56,11 +56,13 @@
  * \brief            PKCS12 Password Based function (encryption / decryption)
  *                   for cipher-based and mbedtls_md-based PBE's
  *
- * \param pbe_params an ASN1 buffer containing the pkcs-12PbeParams structure
- * \param mode       either MBEDTLS_PKCS12_PBE_ENCRYPT or MBEDTLS_PKCS12_PBE_DECRYPT
+ * \param pbe_params an ASN1 buffer containing the pkcs-12 PbeParams structure
+ * \param mode       either #MBEDTLS_PKCS12_PBE_ENCRYPT or
+ *                   #MBEDTLS_PKCS12_PBE_DECRYPT
  * \param cipher_type the cipher used
- * \param md_type     the mbedtls_md used
- * \param pwd        the password used (may be NULL if no password is used)
+ * \param md_type    the mbedtls_md used
+ * \param pwd        Latin1-encoded password used. This may only be \c NULL when
+ *                   \p pwdlen is 0. No null terminator should be used.
  * \param pwdlen     length of the password (may be 0)
  * \param input      the input data
  * \param len        data length
@@ -81,18 +83,24 @@
  *                   to produce pseudo-random bits for a particular "purpose".
  *
  *                   Depending on the given id, this function can produce an
- *                   encryption/decryption key, an nitialization vector or an
+ *                   encryption/decryption key, an initialization vector or an
  *                   integrity key.
  *
  * \param data       buffer to store the derived data in
- * \param datalen    length to fill
- * \param pwd        password to use (may be NULL if no password is used)
- * \param pwdlen     length of the password (may be 0)
- * \param salt       salt buffer to use
- * \param saltlen    length of the salt
- * \param mbedtls_md         mbedtls_md type to use during the derivation
- * \param id         id that describes the purpose (can be MBEDTLS_PKCS12_DERIVE_KEY,
- *                   MBEDTLS_PKCS12_DERIVE_IV or MBEDTLS_PKCS12_DERIVE_MAC_KEY)
+ * \param datalen    length of buffer to fill
+ * \param pwd        The password to use. For compliance with PKCS#12 §B.1, this
+ *                   should be a BMPString, i.e. a Unicode string where each
+ *                   character is encoded as 2 bytes in big-endian order, with
+ *                   no byte order mark and with a null terminator (i.e. the
+ *                   last two bytes should be 0x00 0x00).
+ * \param pwdlen     length of the password (may be 0).
+ * \param salt       Salt buffer to use. This may only be \c NULL when
+ *                   \p saltlen is 0.
+ * \param saltlen    length of the salt (may be zero)
+ * \param mbedtls_md mbedtls_md type to use during the derivation
+ * \param id         id that describes the purpose (can be
+ *                   #MBEDTLS_PKCS12_DERIVE_KEY, #MBEDTLS_PKCS12_DERIVE_IV or
+ *                   #MBEDTLS_PKCS12_DERIVE_MAC_KEY)
  * \param iterations number of iterations
  *
  * \return          0 if successful, or a MD, BIGNUM type error.
diff --git a/library/pkcs12.c b/library/pkcs12.c
index 8f64bc6..a90d1f9 100644
--- a/library/pkcs12.c
+++ b/library/pkcs12.c
@@ -134,6 +134,9 @@
     mbedtls_cipher_context_t cipher_ctx;
     size_t olen = 0;
 
+    if( pwd == NULL && pwdlen != 0 )
+        return( MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA );
+
     cipher_info = mbedtls_cipher_info_from_type( cipher_type );
     if( cipher_info == NULL )
         return( MBEDTLS_ERR_PKCS12_FEATURE_UNAVAILABLE );
@@ -186,12 +189,23 @@
     unsigned char *p = data;
     size_t use_len;
 
-    while( data_len > 0 )
+    if( filler != NULL && fill_len != 0 )
     {
-        use_len = ( data_len > fill_len ) ? fill_len : data_len;
-        memcpy( p, filler, use_len );
-        p += use_len;
-        data_len -= use_len;
+        while( data_len > 0 )
+        {
+            use_len = ( data_len > fill_len ) ? fill_len : data_len;
+            memcpy( p, filler, use_len );
+            p += use_len;
+            data_len -= use_len;
+        }
+    }
+    else
+    {
+        /* Nothing to copy: the filler is NULL or its length is zero. This
+         * function should *not* be called in either of those circumstances,
+         * as the output buffer is left unfilled and the result could be
+         * incorrect. The check is kept for safety's sake: a zero fill_len
+         * would make the loop never terminate, and a NULL filler would be
+         * passed to memcpy(). */
     }
 }
 
@@ -208,6 +222,8 @@
     unsigned char hash_output[MBEDTLS_MD_MAX_SIZE];
     unsigned char *p;
     unsigned char c;
+    int           use_password = 0;
+    int           use_salt = 0;
 
     size_t hlen, use_len, v, i;
 
@@ -218,6 +234,15 @@
     if( datalen > 128 || pwdlen > 64 || saltlen > 64 )
         return( MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA );
 
+    if( pwd == NULL && pwdlen != 0 )
+        return( MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA );
+
+    if( salt == NULL && saltlen != 0 )
+        return( MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA );
+
+    use_password = ( pwd && pwdlen != 0 );
+    use_salt = ( salt && saltlen != 0 );
+
     md_info = mbedtls_md_info_from_type( md_type );
     if( md_info == NULL )
         return( MBEDTLS_ERR_PKCS12_FEATURE_UNAVAILABLE );
@@ -235,8 +260,15 @@
 
     memset( diversifier, (unsigned char) id, v );
 
-    pkcs12_fill_buffer( salt_block, v, salt, saltlen );
-    pkcs12_fill_buffer( pwd_block,  v, pwd,  pwdlen  );
+    if( use_salt != 0 )
+    {
+        pkcs12_fill_buffer( salt_block, v, salt, saltlen );
+    }
+
+    if( use_password != 0 )
+    {
+        pkcs12_fill_buffer( pwd_block,  v, pwd,  pwdlen  );
+    }
 
     p = data;
     while( datalen > 0 )
@@ -248,11 +280,17 @@
         if( ( ret = mbedtls_md_update( &md_ctx, diversifier, v ) ) != 0 )
             goto exit;
 
-        if( ( ret = mbedtls_md_update( &md_ctx, salt_block, v ) ) != 0 )
-            goto exit;
+        if( use_salt != 0 )
+        {
+            if( ( ret = mbedtls_md_update( &md_ctx, salt_block, v )) != 0 )
+                goto exit;
+        }
 
-        if( ( ret = mbedtls_md_update( &md_ctx, pwd_block, v ) ) != 0 )
-            goto exit;
+        if( use_password != 0)
+        {
+            if( ( ret = mbedtls_md_update( &md_ctx, pwd_block, v )) != 0 )
+                goto exit;
+        }
 
         if( ( ret = mbedtls_md_finish( &md_ctx, hash_output ) ) != 0 )
             goto exit;
@@ -280,22 +318,28 @@
             if( ++hash_block[i - 1] != 0 )
                 break;
 
-        // salt_block += B
-        c = 0;
-        for( i = v; i > 0; i-- )
+        if( use_salt != 0 )
         {
-            j = salt_block[i - 1] + hash_block[i - 1] + c;
-            c = MBEDTLS_BYTE_1( j );
-            salt_block[i - 1] = MBEDTLS_BYTE_0( j );
+            // salt_block += B
+            c = 0;
+            for( i = v; i > 0; i-- )
+            {
+                j = salt_block[i - 1] + hash_block[i - 1] + c;
+                c = MBEDTLS_BYTE_1( j );
+                salt_block[i - 1] = MBEDTLS_BYTE_0( j );
+            }
         }
 
-        // pwd_block  += B
-        c = 0;
-        for( i = v; i > 0; i-- )
+        if( use_password != 0 )
         {
-            j = pwd_block[i - 1] + hash_block[i - 1] + c;
-            c = MBEDTLS_BYTE_1( j );
-            pwd_block[i - 1] = MBEDTLS_BYTE_0( j );
+            // pwd_block  += B
+            c = 0;
+            for( i = v; i > 0; i-- )
+            {
+                j = pwd_block[i - 1] + hash_block[i - 1] + c;
+                c = MBEDTLS_BYTE_1( j );
+                pwd_block[i - 1] = MBEDTLS_BYTE_0( j );
+            }
         }
     }
 
diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index 8bb6250..bd7e3b9 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -216,6 +216,7 @@
 add_test_suite(pkcs1_v15)
 add_test_suite(pkcs1_v21)
 add_test_suite(pkcs5)
+add_test_suite(pkcs12)
 add_test_suite(pkparse)
 add_test_suite(pkwrite)
 add_test_suite(poly1305)
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index 72302f3..6e17a91 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -2856,6 +2856,36 @@
     fi
 }
 
+support_test_cmake_out_of_source () {
+    distrib_id=""
+    distrib_ver=""
+    distrib_ver_minor=""
+    distrib_ver_major=""
+
+    # Attempt to parse lsb-release to find out distribution and version. If not
+    # found this should fail safe (test is supported).
+    if [[ -f /etc/lsb-release ]]; then
+
+        while read -r lsb_line; do
+            case "$lsb_line" in
+                "DISTRIB_ID"*) distrib_id=${lsb_line/#DISTRIB_ID=};;
+                "DISTRIB_RELEASE"*) distrib_ver=${lsb_line/#DISTRIB_RELEASE=};;
+            esac
+        done < /etc/lsb-release
+
+        distrib_ver_major="${distrib_ver%%.*}"
+        distrib_ver="${distrib_ver#*.}"
+        distrib_ver_minor="${distrib_ver%%.*}"
+    fi
+
+    # Running the out of source CMake test on Ubuntu 16.04 using more than one
+    # processor (as the CI does) can create a race condition whereby the build
+    # fails to see a generated file, despite that file actually having been
+    # generated. This problem appears to go away with 18.04 or newer, so make
+    # the out of source tests unsupported on Ubuntu 16.04.
+    [ "$distrib_id" != "Ubuntu" ] || [ "$distrib_ver_major" -gt 16 ]
+}
+
 component_test_cmake_out_of_source () {
     msg "build: cmake 'out-of-source' build"
     MBEDTLS_ROOT_DIR="$PWD"
diff --git a/tests/suites/test_suite_pkcs12.data b/tests/suites/test_suite_pkcs12.data
new file mode 100644
index 0000000..a8c4bab
--- /dev/null
+++ b/tests/suites/test_suite_pkcs12.data
@@ -0,0 +1,35 @@
+PKCS#12 derive key: MD5: Zero length password and salt
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"":USE_GIVEN_INPUT:"":USE_GIVEN_INPUT:3:"6afdcbd5ebf943272134f1c3de2dc11b6afdcbd5ebf943272134f1c3de2dc11b6afdcbd5ebf943272134f1c3de2dc11b":0
+
+PKCS#12 derive key: MD5: NULL password and salt
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"":USE_NULL_INPUT:"":USE_NULL_INPUT:3:"6afdcbd5ebf943272134f1c3de2dc11b6afdcbd5ebf943272134f1c3de2dc11b6afdcbd5ebf943272134f1c3de2dc11b":0
+
+PKCS#12 derive key: MD5: Zero length password
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"":USE_GIVEN_INPUT:"0123456789abcdef":USE_GIVEN_INPUT:3:"832d8502114fcccfd3de0c2b2863b1c45fb92a8db2ed1e704727b324adc267bdd66ae4918a81fa2d1ba15febfb9e6c4e":0
+
+PKCS#12 derive key: MD5: NULL password
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"":USE_NULL_INPUT:"0123456789abcdef":USE_GIVEN_INPUT:3:"832d8502114fcccfd3de0c2b2863b1c45fb92a8db2ed1e704727b324adc267bdd66ae4918a81fa2d1ba15febfb9e6c4e":0
+
+PKCS#12 derive key: MD5: Invalid length NULL password
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"0123456789abcdef":USE_NULL_INPUT:"0123456789abcdef":USE_GIVEN_INPUT:3:"":MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA
+
+PKCS#12 derive key: MD5: Zero length salt
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"0123456789abcdef":USE_GIVEN_INPUT:"":USE_GIVEN_INPUT:3:"832d8502114fcccfd3de0c2b2863b1c45fb92a8db2ed1e704727b324adc267bdd66ae4918a81fa2d1ba15febfb9e6c4e":0
+
+PKCS#12 derive key: MD5: NULL salt
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"0123456789abcdef":USE_GIVEN_INPUT:"":USE_NULL_INPUT:3:"832d8502114fcccfd3de0c2b2863b1c45fb92a8db2ed1e704727b324adc267bdd66ae4918a81fa2d1ba15febfb9e6c4e":0
+
+PKCS#12 derive key: MD5: Invalid length NULL salt
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"0123456789abcdef":USE_GIVEN_INPUT:"0123456789abcdef":USE_NULL_INPUT:3:"":MBEDTLS_ERR_PKCS12_BAD_INPUT_DATA
+
+PKCS#12 derive key: MD5: Valid password and salt
+depends_on:MBEDTLS_MD5_C
+pkcs12_derive_key:MBEDTLS_MD_MD5:48:"0123456789abcdef":USE_GIVEN_INPUT:"0123456789abcdef":USE_GIVEN_INPUT:3:"46559deeee036836ab1b633ec620178d4c70eacf42f72a2ad7360c812efa09ca3d7567b489a109050345c2dc6a262995":0
diff --git a/tests/suites/test_suite_pkcs12.function b/tests/suites/test_suite_pkcs12.function
new file mode 100644
index 0000000..54dc042
--- /dev/null
+++ b/tests/suites/test_suite_pkcs12.function
@@ -0,0 +1,69 @@
+/* BEGIN_HEADER */
+#include "mbedtls/pkcs12.h"
+#include "common.h"
+
+typedef enum
+{
+   USE_NULL_INPUT = 0,
+   USE_GIVEN_INPUT = 1,
+} input_usage_method_t;
+
+/* END_HEADER */
+
+/* BEGIN_DEPENDENCIES
+ * depends_on:MBEDTLS_PKCS12_C
+ * END_DEPENDENCIES
+ */
+
+/* BEGIN_CASE */
+void pkcs12_derive_key( int md_type, int key_size_arg,
+                        data_t *password_arg, int password_usage,
+                        data_t *salt_arg, int salt_usage,
+                        int iterations,
+                        data_t* expected_output, int expected_status )
+
+{
+   int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED;
+   unsigned char *output_data = NULL;
+
+   unsigned char *password = NULL;
+   size_t password_len = 0;
+   unsigned char *salt = NULL;
+   size_t salt_len = 0;
+   size_t key_size = key_size_arg;
+
+   if( password_usage == USE_GIVEN_INPUT )
+      password = password_arg->x;
+
+   password_len = password_arg->len;
+
+   if( salt_usage == USE_GIVEN_INPUT )
+      salt = salt_arg->x;
+
+   salt_len = salt_arg->len;
+
+   ASSERT_ALLOC( output_data, key_size );
+
+   ret = mbedtls_pkcs12_derivation( output_data,
+                                    key_size,
+                                    password,
+                                    password_len,
+                                    salt,
+                                    salt_len,
+                                    md_type,
+                                    MBEDTLS_PKCS12_DERIVE_KEY,
+                                    iterations );
+
+   TEST_EQUAL( ret, expected_status );
+
+   if( expected_status == 0 )
+   {
+      ASSERT_COMPARE( expected_output->x, expected_output->len,
+                      output_data, key_size );
+   }
+
+exit:
+   mbedtls_free( output_data );
+
+}
+/* END_CASE */