Merge pull request #8822 from daverodgman/sha3-perf

SHA-3 performance & code size
diff --git a/library/sha3.c b/library/sha3.c
index 27d495f..5738559 100644
--- a/library/sha3.c
+++ b/library/sha3.c
@@ -14,6 +14,33 @@
 
 #if defined(MBEDTLS_SHA3_C)
 
+/*
+ * These macros select between looped and manually unrolled implementations of parts of the
+ * main Keccak permutation. Each is 0 for the looped form or non-zero for the unrolled form;
+ * a value that is already defined (e.g. via -D on the compiler command line) is left as is.
+ *
+ * Unrolling has a major impact on both performance and code size; gcc in particular benefits a
+ * lot from manual unrolling at higher optimisation levels. The defaults below should give
+ * sensible size/performance trade-offs for gcc and clang on aarch64 and x86-64, with Chi kept
+ * looped when __OPTIMIZE_SIZE__ is defined. Other targets or priorities may favour other values.
+ */
+#if !defined(MBEDTLS_SHA3_THETA_UNROLL)
+    #define MBEDTLS_SHA3_THETA_UNROLL 0 //no-check-names
+#endif
+#if !defined(MBEDTLS_SHA3_CHI_UNROLL)
+    #if defined(__OPTIMIZE_SIZE__)
+        #define MBEDTLS_SHA3_CHI_UNROLL 0 //no-check-names
+    #else
+        #define MBEDTLS_SHA3_CHI_UNROLL 1 //no-check-names
+    #endif
+#endif
+#if !defined(MBEDTLS_SHA3_PI_UNROLL)
+    #define MBEDTLS_SHA3_PI_UNROLL 1 //no-check-names
+#endif
+#if !defined(MBEDTLS_SHA3_RHO_UNROLL)
+    #define MBEDTLS_SHA3_RHO_UNROLL 1 //no-check-names
+#endif
+
 #include "mbedtls/sha3.h"
 #include "mbedtls/platform_util.h"
 #include "mbedtls/error.h"
@@ -56,18 +83,25 @@
 };
 #undef H
 
-static const uint8_t rho[24] = {
-    1, 62, 28, 27, 36, 44,  6, 55, 20,
-    3, 10, 43, 25, 39, 41, 45, 15,
-    21,  8, 18,  2, 61, 56, 14
+/*
+ * Rho: right-rotation amount for each of lanes 1..24 (lane 0 is not rotated), i.e. 64 minus
+ * the standard Keccak left-rotation offset. Packed four per word, with the amount for the
+ * lowest-numbered lane in the most significant byte.
+ */
+static const uint32_t rho[6] = {
+    0x3f022425, 0x1c143a09, 0x2c3d3615, 0x27191713, 0x312b382e, 0x3e030832
 };
 
-static const uint8_t pi[24] = {
-    10,  7, 11, 17, 18, 3,  5, 16,  8, 21, 24, 4,
-    15, 23, 19, 13, 12, 2, 20, 14, 22,  9,  6, 1,
+/*
+ * Pi: lane indices swapped, in order, with a carried lane that starts as lane 1 (see the Pi
+ * step). Packed four per word, with the first index in the least significant byte.
+ */
+static const uint32_t pi[6] = {
+    0x110b070a, 0x10050312, 0x04181508, 0x0d13170f, 0x0e14020c, 0x01060916
 };
 
-#define ROT64(x, y) (((x) << (y)) | ((x) >> (64U - (y))))
+/* The left shift count is masked so that y == 0 never shifts a 64-bit value by 64 (undefined). */
+#define ROTR64(x, y) (((x) << ((64U - (y)) & 63U)) | ((x) >> (y))) // 64-bit rotate right, y < 64
 #define ABSORB(ctx, idx, v) do { ctx->state[(idx) >> 3] ^= ((uint64_t) (v)) << (((idx) & 0x7) << 3); \
 } while (0)
 #define SQUEEZE(ctx, idx) ((uint8_t) (ctx->state[(idx) >> 3] >> (((idx) & 0x7) << 3)))
@@ -84,39 +108,97 @@
         uint64_t t;
 
         /* Theta */
+#if MBEDTLS_SHA3_THETA_UNROLL == 0 //no-check-names
+        for (i = 0; i < 5; i++) {
+            lane[i] = s[i] ^ s[i + 5] ^ s[i + 10] ^ s[i + 15] ^ s[i + 20];
+        }
+        for (i = 0; i < 5; i++) {
+            t = lane[(i + 4) % 5] ^ ROTR64(lane[(i + 1) % 5], 63);
+            s[i] ^= t; s[i + 5] ^= t; s[i + 10] ^= t; s[i + 15] ^= t; s[i + 20] ^= t;
+        }
+#else
         lane[0] = s[0] ^ s[5] ^ s[10] ^ s[15] ^ s[20];
         lane[1] = s[1] ^ s[6] ^ s[11] ^ s[16] ^ s[21];
         lane[2] = s[2] ^ s[7] ^ s[12] ^ s[17] ^ s[22];
         lane[3] = s[3] ^ s[8] ^ s[13] ^ s[18] ^ s[23];
         lane[4] = s[4] ^ s[9] ^ s[14] ^ s[19] ^ s[24];
 
-        t = lane[4] ^ ROT64(lane[1], 1);
+        t = lane[4] ^ ROTR64(lane[1], 63);
         s[0] ^= t; s[5] ^= t; s[10] ^= t; s[15] ^= t; s[20] ^= t;
 
-        t = lane[0] ^ ROT64(lane[2], 1);
+        t = lane[0] ^ ROTR64(lane[2], 63);
         s[1] ^= t; s[6] ^= t; s[11] ^= t; s[16] ^= t; s[21] ^= t;
 
-        t = lane[1] ^ ROT64(lane[3], 1);
+        t = lane[1] ^ ROTR64(lane[3], 63);
         s[2] ^= t; s[7] ^= t; s[12] ^= t; s[17] ^= t; s[22] ^= t;
 
-        t = lane[2] ^ ROT64(lane[4], 1);
+        t = lane[2] ^ ROTR64(lane[4], 63);
         s[3] ^= t; s[8] ^= t; s[13] ^= t; s[18] ^= t; s[23] ^= t;
 
-        t = lane[3] ^ ROT64(lane[0], 1);
+        t = lane[3] ^ ROTR64(lane[0], 63);
         s[4] ^= t; s[9] ^= t; s[14] ^= t; s[19] ^= t; s[24] ^= t;
+#endif
 
         /* Rho */
-        for (i = 1; i < 25; i++) {
-            s[i] = ROT64(s[i], rho[i-1]);
+        for (i = 1; i < 25; i += 4) {
+            uint32_t r = rho[(i - 1) >> 2];
+#if MBEDTLS_SHA3_RHO_UNROLL == 0
+            for (int j = i; j < i + 4; j++) {
+                uint8_t r8 = (uint8_t) (r >> 24);
+                r <<= 8;
+                s[j] = ROTR64(s[j], r8);
+            }
+#else
+            s[i + 0] = ROTR64(s[i + 0], MBEDTLS_BYTE_3(r));
+            s[i + 1] = ROTR64(s[i + 1], MBEDTLS_BYTE_2(r));
+            s[i + 2] = ROTR64(s[i + 2], MBEDTLS_BYTE_1(r));
+            s[i + 3] = ROTR64(s[i + 3], MBEDTLS_BYTE_0(r));
+#endif
         }
 
         /* Pi */
         t = s[1];
-        for (i = 0; i < 24; i++) {
-            SWAP(s[pi[i]], t);
+#if MBEDTLS_SHA3_PI_UNROLL == 0
+        for (i = 0; i < 24; i += 4) {
+            uint32_t p = pi[i >> 2];
+            for (unsigned j = 0; j < 4; j++) {
+                SWAP(s[p & 0xff], t);
+                p >>= 8;
+            }
         }
+#else
+        uint32_t p = pi[0];
+        SWAP(s[MBEDTLS_BYTE_0(p)], t); SWAP(s[MBEDTLS_BYTE_1(p)], t);
+        SWAP(s[MBEDTLS_BYTE_2(p)], t); SWAP(s[MBEDTLS_BYTE_3(p)], t);
+        p = pi[1];
+        SWAP(s[MBEDTLS_BYTE_0(p)], t); SWAP(s[MBEDTLS_BYTE_1(p)], t);
+        SWAP(s[MBEDTLS_BYTE_2(p)], t); SWAP(s[MBEDTLS_BYTE_3(p)], t);
+        p = pi[2];
+        SWAP(s[MBEDTLS_BYTE_0(p)], t); SWAP(s[MBEDTLS_BYTE_1(p)], t);
+        SWAP(s[MBEDTLS_BYTE_2(p)], t); SWAP(s[MBEDTLS_BYTE_3(p)], t);
+        p = pi[3];
+        SWAP(s[MBEDTLS_BYTE_0(p)], t); SWAP(s[MBEDTLS_BYTE_1(p)], t);
+        SWAP(s[MBEDTLS_BYTE_2(p)], t); SWAP(s[MBEDTLS_BYTE_3(p)], t);
+        p = pi[4];
+        SWAP(s[MBEDTLS_BYTE_0(p)], t); SWAP(s[MBEDTLS_BYTE_1(p)], t);
+        SWAP(s[MBEDTLS_BYTE_2(p)], t); SWAP(s[MBEDTLS_BYTE_3(p)], t);
+        p = pi[5];
+        SWAP(s[MBEDTLS_BYTE_0(p)], t); SWAP(s[MBEDTLS_BYTE_1(p)], t);
+        SWAP(s[MBEDTLS_BYTE_2(p)], t); SWAP(s[MBEDTLS_BYTE_3(p)], t);
+#endif
 
         /* Chi */
+#if MBEDTLS_SHA3_CHI_UNROLL == 0 //no-check-names
+        for (i = 0; i <= 20; i += 5) {
+            lane[0] = s[i]; lane[1] = s[i + 1]; lane[2] = s[i + 2];
+            lane[3] = s[i + 3]; lane[4] = s[i + 4];
+            s[i + 0] ^= (~lane[1]) & lane[2];
+            s[i + 1] ^= (~lane[2]) & lane[3];
+            s[i + 2] ^= (~lane[3]) & lane[4];
+            s[i + 3] ^= (~lane[4]) & lane[0];
+            s[i + 4] ^= (~lane[0]) & lane[1];
+        }
+#else
         lane[0] = s[0]; lane[1] = s[1]; lane[2] = s[2]; lane[3] = s[3]; lane[4] = s[4];
         s[0] ^= (~lane[1]) & lane[2];
         s[1] ^= (~lane[2]) & lane[3];
@@ -151,6 +233,7 @@
         s[22] ^= (~lane[3]) & lane[4];
         s[23] ^= (~lane[4]) & lane[0];
         s[24] ^= (~lane[0]) & lane[1];
+#endif
 
         /* Iota */
         /* Decompress the round masks (see definition of rc) */
diff --git a/tests/scripts/all.sh b/tests/scripts/all.sh
index 9191f65..21071bd 100755
--- a/tests/scripts/all.sh
+++ b/tests/scripts/all.sh
@@ -4781,6 +4781,26 @@
     not grep -q "AES note: built-in implementation." ./programs/test/selftest
 }
 
+component_test_sha3_variations() {
+    msg "sha3 loop unroll variations"
+    # Build and run test_suite_shax twice: every MBEDTLS_SHA3_*_UNROLL option set to 1, then to 0.
+    # Overwrite mbedtls_config.h with a minimal config sufficient to test SHA3.
+    cat > include/mbedtls/mbedtls_config.h << END
+        #define MBEDTLS_SELF_TEST
+        #define MBEDTLS_SHA3_C
+END
+
+    msg "all loops unrolled"
+    make clean # discard objects built with different SHA-3 unroll settings
+    make -C tests test_suite_shax CFLAGS="-DMBEDTLS_SHA3_THETA_UNROLL=1 -DMBEDTLS_SHA3_PI_UNROLL=1 -DMBEDTLS_SHA3_CHI_UNROLL=1 -DMBEDTLS_SHA3_RHO_UNROLL=1"
+    ./tests/test_suite_shax
+
+    msg "all loops rolled up"
+    make clean # discard objects built with different SHA-3 unroll settings
+    make -C tests test_suite_shax CFLAGS="-DMBEDTLS_SHA3_THETA_UNROLL=0 -DMBEDTLS_SHA3_PI_UNROLL=0 -DMBEDTLS_SHA3_CHI_UNROLL=0 -DMBEDTLS_SHA3_RHO_UNROLL=0"
+    ./tests/test_suite_shax
+}
+
 support_test_aesni_m32() {
     support_test_m32_no_asm && (lscpu | grep -qw aes)
 }