Refactored RSA to pass a random number generator to every RSA operation

Primarily so that rsa_private() receives an RNG for blinding purposes.
The RNG is not used yet: rsa_private() currently ignores f_rng and p_rng.
diff --git a/library/rsa.c b/library/rsa.c
index ccdd048..2cb3742 100644
--- a/library/rsa.c
+++ b/library/rsa.c
@@ -257,12 +257,16 @@
  * Do an RSA private key operation
  */
 int rsa_private( rsa_context *ctx,
+                 int (*f_rng)(void *, unsigned char *, size_t),
+                 void *p_rng,
                  const unsigned char *input,
                  unsigned char *output )
 {
     int ret;
     size_t olen;
     mpi T, T1, T2;
+    ((void) f_rng);
+    ((void) p_rng);
 
     mpi_init( &T ); mpi_init( &T1 ); mpi_init( &T2 );
 
@@ -430,7 +434,7 @@
 
     return( ( mode == RSA_PUBLIC )
             ? rsa_public(  ctx, output, output )
-            : rsa_private( ctx, output, output ) );
+            : rsa_private( ctx, f_rng, p_rng, output, output ) );
 }
 #endif /* POLARSSL_PKCS1_V21 */
 
@@ -492,7 +496,7 @@
 
     return( ( mode == RSA_PUBLIC )
             ? rsa_public(  ctx, output, output )
-            : rsa_private( ctx, output, output ) );
+            : rsa_private( ctx, f_rng, p_rng, output, output ) );
 }
 
 /*
@@ -527,7 +531,9 @@
  * Implementation of the PKCS#1 v2.1 RSAES-OAEP-DECRYPT function
  */
 int rsa_rsaes_oaep_decrypt( rsa_context *ctx,
-                            int mode, 
+                            int (*f_rng)(void *, unsigned char *, size_t),
+                            void *p_rng,
+                            int mode,
                             const unsigned char *label, size_t label_len,
                             size_t *olen,
                             const unsigned char *input,
@@ -553,7 +559,7 @@
 
     ret = ( mode == RSA_PUBLIC )
           ? rsa_public(  ctx, input, buf )
-          : rsa_private( ctx, input, buf );
+          : rsa_private( ctx, f_rng, p_rng, input, buf );
 
     if( ret != 0 )
         return( ret );
@@ -618,6 +624,8 @@
  * Implementation of the PKCS#1 v2.1 RSAES-PKCS1-V1_5-DECRYPT function
  */
 int rsa_rsaes_pkcs1_v15_decrypt( rsa_context *ctx,
+                                 int (*f_rng)(void *, unsigned char *, size_t),
+                                 void *p_rng,
                                  int mode, size_t *olen,
                                  const unsigned char *input,
                                  unsigned char *output,
@@ -639,7 +647,7 @@
 
     ret = ( mode == RSA_PUBLIC )
           ? rsa_public(  ctx, input, buf )
-          : rsa_private( ctx, input, buf );
+          : rsa_private( ctx, f_rng, p_rng, input, buf );
 
     if( ret != 0 )
         return( ret );
@@ -711,6 +719,8 @@
  * Do an RSA operation, then remove the message padding
  */
 int rsa_pkcs1_decrypt( rsa_context *ctx,
+                       int (*f_rng)(void *, unsigned char *, size_t),
+                       void *p_rng,
                        int mode, size_t *olen,
                        const unsigned char *input,
                        unsigned char *output,
@@ -719,13 +729,14 @@
     switch( ctx->padding )
     {
         case RSA_PKCS_V15:
-            return rsa_rsaes_pkcs1_v15_decrypt( ctx, mode, olen, input, output,
-                                                output_max_len );
+            return rsa_rsaes_pkcs1_v15_decrypt( ctx, f_rng, p_rng, mode, olen,
+                                                input, output, output_max_len );
 
 #if defined(POLARSSL_PKCS1_V21)
         case RSA_PKCS_V21:
-            return rsa_rsaes_oaep_decrypt( ctx, mode, NULL, 0, olen, input,
-                                           output, output_max_len );
+            return rsa_rsaes_oaep_decrypt( ctx, f_rng, p_rng, mode, NULL, 0,
+                                           olen, input, output,
+                                           output_max_len );
 #endif
 
         default:
@@ -827,7 +838,7 @@
 
     return( ( mode == RSA_PUBLIC )
             ? rsa_public(  ctx, sig, sig )
-            : rsa_private( ctx, sig, sig ) );
+            : rsa_private( ctx, f_rng, p_rng, sig, sig ) );
 }
 #endif /* POLARSSL_PKCS1_V21 */
 
@@ -838,6 +849,8 @@
  * Do an RSA operation to sign the message digest
  */
 int rsa_rsassa_pkcs1_v15_sign( rsa_context *ctx,
+                               int (*f_rng)(void *, unsigned char *, size_t),
+                               void *p_rng,
                                int mode,
                                md_type_t md_alg,
                                unsigned int hashlen,
@@ -912,7 +925,7 @@
 
     return( ( mode == RSA_PUBLIC )
             ? rsa_public(  ctx, sig, sig )
-            : rsa_private( ctx, sig, sig ) );
+            : rsa_private( ctx, f_rng, p_rng, sig, sig ) );
 }
 
 /*
@@ -930,7 +943,7 @@
     switch( ctx->padding )
     {
         case RSA_PKCS_V15:
-            return rsa_rsassa_pkcs1_v15_sign( ctx, mode, md_alg,
+            return rsa_rsassa_pkcs1_v15_sign( ctx, f_rng, p_rng, mode, md_alg,
                                               hashlen, hash, sig );
 
 #if defined(POLARSSL_PKCS1_V21)
@@ -949,6 +962,8 @@
  * Implementation of the PKCS#1 v2.1 RSASSA-PSS-VERIFY function
  */
 int rsa_rsassa_pss_verify( rsa_context *ctx,
+                           int (*f_rng)(void *, unsigned char *, size_t),
+                           void *p_rng,
                            int mode,
                            md_type_t md_alg,
                            unsigned int hashlen,
@@ -976,7 +991,7 @@
 
     ret = ( mode == RSA_PUBLIC )
           ? rsa_public(  ctx, sig, buf )
-          : rsa_private( ctx, sig, buf );
+          : rsa_private( ctx, f_rng, p_rng, sig, buf );
 
     if( ret != 0 )
         return( ret );
@@ -1059,6 +1074,8 @@
  * Implementation of the PKCS#1 v2.1 RSASSA-PKCS1-v1_5-VERIFY function
  */
 int rsa_rsassa_pkcs1_v15_verify( rsa_context *ctx,
+                                 int (*f_rng)(void *, unsigned char *, size_t),
+                                 void *p_rng,
                                  int mode,
                                  md_type_t md_alg,
                                  unsigned int hashlen,
@@ -1083,7 +1100,7 @@
 
     ret = ( mode == RSA_PUBLIC )
           ? rsa_public(  ctx, sig, buf )
-          : rsa_private( ctx, sig, buf );
+          : rsa_private( ctx, f_rng, p_rng, sig, buf );
 
     if( ret != 0 )
         return( ret );
@@ -1173,6 +1190,8 @@
  * Do an RSA operation and check the message digest
  */
 int rsa_pkcs1_verify( rsa_context *ctx,
+                      int (*f_rng)(void *, unsigned char *, size_t),
+                      void *p_rng,
                       int mode,
                       md_type_t md_alg,
                       unsigned int hashlen,
@@ -1182,12 +1201,12 @@
     switch( ctx->padding )
     {
         case RSA_PKCS_V15:
-            return rsa_rsassa_pkcs1_v15_verify( ctx, mode, md_alg,
+            return rsa_rsassa_pkcs1_v15_verify( ctx, f_rng, p_rng, mode, md_alg,
                                                 hashlen, hash, sig );
 
 #if defined(POLARSSL_PKCS1_V21)
         case RSA_PKCS_V21:
-            return rsa_rsassa_pss_verify( ctx, mode, md_alg,
+            return rsa_rsassa_pss_verify( ctx, f_rng, p_rng, mode, md_alg,
                                           hashlen, hash, sig );
 #endif
 
@@ -1355,7 +1374,7 @@
 
     memcpy( rsa_plaintext, RSA_PT, PT_LEN );
 
-    if( rsa_pkcs1_encrypt( &rsa, &myrand, NULL, RSA_PUBLIC, PT_LEN,
+    if( rsa_pkcs1_encrypt( &rsa, myrand, NULL, RSA_PUBLIC, PT_LEN,
                            rsa_plaintext, rsa_ciphertext ) != 0 )
     {
         if( verbose != 0 )
@@ -1367,7 +1386,7 @@
     if( verbose != 0 )
         printf( "passed\n  PKCS#1 decryption : " );
 
-    if( rsa_pkcs1_decrypt( &rsa, RSA_PRIVATE, &len,
+    if( rsa_pkcs1_decrypt( &rsa, myrand, NULL, RSA_PRIVATE, &len,
                            rsa_ciphertext, rsa_decrypted,
                            sizeof(rsa_decrypted) ) != 0 )
     {
@@ -1403,7 +1422,7 @@
     if( verbose != 0 )
         printf( "passed\n  PKCS#1 sig. verify: " );
 
-    if( rsa_pkcs1_verify( &rsa, RSA_PUBLIC, POLARSSL_MD_SHA1, 0,
+    if( rsa_pkcs1_verify( &rsa, NULL, NULL, RSA_PUBLIC, POLARSSL_MD_SHA1, 0,
                           sha1sum, rsa_ciphertext ) != 0 )
     {
         if( verbose != 0 )