Check hamming weights of AES key shares (#252)

* Add hamming weight checks

* Rework to deal with whole shares, not word-wise

* Fix logging

* Revert "Rework to deal with whole shares, not word-wise"

This reverts commit 65245f897abe467c50e24ffc618987c46498f577.

* Switch to checking half-words, and add fixed number of attempts (100,000) before failing

* Switch to checking numbers are within middle ranges, rather than checking deltas
diff --git a/main.cpp b/main.cpp
index 860928f..0aef1b9 100644
--- a/main.cpp
+++ b/main.cpp
@@ -5189,18 +5189,78 @@
         }
     }
 
+    int min_weight = 6;
+    int max_weight = 10;
+    int min_combined_weight = 28;
+    int max_combined_weight = 36;
+
     if (!keyIsShare) {
+        std::vector<int> num_attempts;
         // Generate a random key share from 256-bit key
         std::random_device rand{};
         assert(rand.max() - rand.min() >= 256);
         for(int i=0; i < 8; i++) {
-            for (int j=0; j < 12; j++) {
-                aes_key_share.bytes[i*16 + j] = rand();
+            // Regenerate the share word until the hamming weights are close to each other
+            bool pass = false;
+            for (int attempt=0; attempt < 100000; attempt++) {
+                for (int j=0; j < 12; j++) {
+                    aes_key_share.bytes[i*16 + j] = rand();
+                }
+                aes_key_share.words[i*4 + 3] = aes_key.words[i]
+                                            ^ aes_key_share.words[i*4]
+                                            ^ aes_key_share.words[i*4 + 1]
+                                            ^ aes_key_share.words[i*4 + 2];
+
+                pass = true;
+                for (int half=0; half < 2; half++) {
+                    uint8_t combined_weight = 0;
+                    for (int j=0; j < 4; j++) {
+                        uint16_t half_word = aes_key_share.words[i*4 + j] >> half*16;
+                        uint8_t weight = __builtin_popcount(half_word);
+                        if (weight < min_weight || weight > max_weight) {
+                            DEBUG_LOG("Generated share %d word %d half %d has hamming weights out of range %d -> %d - regenerating attempt %d\n", j, i, half, min_weight, max_weight, attempt);
+                            pass = false;
+                            break;
+                        }
+                        combined_weight += weight;
+                    }
+                    if (!pass) break;
+                    if (combined_weight < min_combined_weight || combined_weight > max_combined_weight) {
+                        DEBUG_LOG("Generated share word %d half %d has hamming weights out of range %d -> %d - regenerating attempt %d\n", i, half, min_combined_weight, max_combined_weight, attempt);
+                        pass = false;
+                        break;
+                    }
+                }
+                if (pass) {
+                    num_attempts.push_back(attempt);
+                    break;
+                }
             }
-            aes_key_share.words[i*4 + 3] = aes_key.words[i]
-                                        ^ aes_key_share.words[i*4]
-                                        ^ aes_key_share.words[i*4 + 1]
-                                        ^ aes_key_share.words[i*4 + 2];
+            if (!pass) {
+                fail(ERROR_INCOMPATIBLE, "Failed to generate a share word with hamming weights within %d -> %d", min_weight, max_weight);
+            }
+        }
+
+        DEBUG_LOG("Average number of attempts: %d\n", std::accumulate(num_attempts.begin(), num_attempts.end(), 0) / num_attempts.size());
+        DEBUG_LOG("Max number of attempts: %d\n", *std::max_element(num_attempts.begin(), num_attempts.end()));
+        DEBUG_LOG("Min number of attempts: %d\n", *std::min_element(num_attempts.begin(), num_attempts.end()));
+    } else {
+        // Check the share word hamming weights are close to each other
+        for(int i=0; i < 8; i++) {
+            for (int half=0; half < 2; half++) {
+                uint8_t combined_weight = 0;
+                for (int j=0; j < 4; j++) {
+                    uint16_t half_word = aes_key_share.words[i*4 + j] >> half*16;
+                    uint8_t weight = __builtin_popcount(half_word);
+                    if (weight < min_weight || weight > max_weight) {
+                        std::cout << "WARNING: Key Share " << j << " Word " << i << " half " << half << " has hamming weights out of range " << min_weight << " -> " << max_weight << " - this may leak information about the key\n";
+                    }
+                    combined_weight += weight;
+                }
+                if (combined_weight < min_combined_weight || combined_weight > max_combined_weight) {
+                    std::cout << "WARNING: Key Share Word " << i << " half " << half << " has hamming weights out of range " << min_combined_weight << " -> " << max_combined_weight << " - this may leak information about the key\n";
+                }
+            }
         }
     }