Add rsa_rsassa_pss_verify_ext()
diff --git a/library/rsa.c b/library/rsa.c
index 1e84d9f..32d829c 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -1106,14 +1106,16 @@
 /*
  * Implementation of the PKCS#1 v2.1 RSASSA-PSS-VERIFY function
  */
-int rsa_rsassa_pss_verify( rsa_context *ctx,
-                           int (*f_rng)(void *, unsigned char *, size_t),
-                           void *p_rng,
-                           int mode,
-                           md_type_t md_alg,
-                           unsigned int hashlen,
-                           const unsigned char *hash,
-                           const unsigned char *sig )
+int rsa_rsassa_pss_verify_ext( rsa_context *ctx,
+                               int (*f_rng)(void *, unsigned char *, size_t),
+                               void *p_rng,
+                               int mode,
+                               md_type_t md_alg,
+                               unsigned int hashlen,
+                               const unsigned char *hash,
+                               md_type_t mgf1_hash_id,
+                               int expected_salt_len,
+                               const unsigned char *sig )
 {
     int ret;
     size_t siglen;
@@ -1157,13 +1159,12 @@
         hashlen = md_get_size( md_info );
     }
 
-    md_info = md_info_from_type( ctx->hash_id != POLARSSL_MD_NONE ?
-                                 ctx->hash_id : md_alg );
+    md_info = md_info_from_type( mgf1_hash_id );
     if( md_info == NULL )
         return( POLARSSL_ERR_RSA_BAD_INPUT_DATA );
 
     hlen = md_get_size( md_info );
-    slen = siglen - hlen - 1;
+    slen = siglen - hlen - 1; /* Currently length of salt + padding */
 
     memset( zeros, 0, 8 );
 
@@ -1197,8 +1198,15 @@
         return( POLARSSL_ERR_RSA_INVALID_PADDING );
     }
 
+    /* Actual salt len */
     slen -= p - buf;
 
+    if( expected_salt_len != RSA_SALT_LEN_ANY &&
+        slen != (size_t) expected_salt_len )
+    {
+        return( POLARSSL_ERR_RSA_INVALID_PADDING );
+    }
+
     // Generate H = Hash( M' )
     //
     md_starts( &md_ctx );
@@ -1214,6 +1222,29 @@
     else
         return( POLARSSL_ERR_RSA_VERIFY_FAILED );
 }
+
+/*
+ * Simplified PKCS#1 v2.1 RSASSA-PSS-VERIFY function (wraps the _ext variant)
+ */
+int rsa_rsassa_pss_verify( rsa_context *ctx,
+                           int (*f_rng)(void *, unsigned char *, size_t),
+                           void *p_rng,
+                           int mode,
+                           md_type_t md_alg,
+                           unsigned int hashlen,
+                           const unsigned char *hash,
+                           const unsigned char *sig )
+{
+    md_type_t mgf1_hash_id = ( ctx->hash_id != POLARSSL_MD_NONE )
+                             ? ctx->hash_id /* hash configured on the context */
+                             : md_alg;      /* else fall back to md_alg */
+    /* RSA_SALT_LEN_ANY makes the _ext variant skip its salt length check */
+    return( rsa_rsassa_pss_verify_ext( ctx, f_rng, p_rng, mode,
+                                       md_alg, hashlen, hash,
+                                       mgf1_hash_id, RSA_SALT_LEN_ANY,
+                                       sig ) );
+
+}
 #endif /* POLARSSL_PKCS1_V21 */
 
 #if defined(POLARSSL_PKCS1_V15)