AES256
diff --git a/include/picotls/fusion.h b/include/picotls/fusion.h
index 786df27..85cdc0d 100644
--- a/include/picotls/fusion.h
+++ b/include/picotls/fusion.h
@@ -30,23 +30,26 @@
#include <emmintrin.h>
#include "../picotls.h"
-#define PTLS_FUSION_AES_ROUNDS 10 /* TODO support AES256 */
+#define PTLS_FUSION_AES128_ROUNDS 10
+#define PTLS_FUSION_AES256_ROUNDS 14
typedef struct ptls_fusion_aesecb_context {
- __m128i keys[PTLS_FUSION_AES_ROUNDS + 1];
+ __m128i keys[PTLS_FUSION_AES256_ROUNDS + 1];
+ unsigned rounds;
} ptls_fusion_aesecb_context_t;
typedef struct ptls_fusion_aesgcm_context ptls_fusion_aesgcm_context_t;
-void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, const void *key);
+void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, int is_enc, const void *key, size_t key_size);
void ptls_fusion_aesecb_dispose(ptls_fusion_aesecb_context_t *ctx);
+void ptls_fusion_aesecb_encrypt(ptls_fusion_aesecb_context_t *ctx, void *dst, const void *src);
/**
* Creates an AES-GCM context.
* @param key the AES key (128 bits)
* @param max_size maximum size of the record (i.e. AAD + encrypted payload)
*/
-ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t max_size);
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t max_size);
/**
* Destroys an AES-GCM context.
*/
diff --git a/lib/fusion.c b/lib/fusion.c
index 70b229a..8be2bcb 100644
--- a/lib/fusion.c
+++ b/lib/fusion.c
@@ -187,7 +187,7 @@
size_t i;
v = _mm_xor_si128(v, ctx->keys[0]);
- for (i = 1; i < PTLS_FUSION_AES_ROUNDS; ++i)
+ for (i = 1; i < ctx->rounds; ++i)
v = _mm_aesenc_si128(v, ctx->keys[i]);
v = _mm_aesenclast_si128(v, ctx->keys[i]);
@@ -272,14 +272,14 @@
} while (0)
/* aesenclast */
-#define AESECB6_FINAL() \
+#define AESECB6_FINAL(i) \
do { \
- __m128i k = ctx->ecb.keys[10]; \
+ __m128i k = ctx->ecb.keys[i]; \
bits0 = _mm_aesenclast_si128(bits0, k); \
bits1 = _mm_aesenclast_si128(bits1, k); \
bits2 = _mm_aesenclast_si128(bits2, k); \
bits3 = _mm_aesenclast_si128(bits3, k); \
- bits4 = _mm_aesenclast_si128(bits4, bits4keys[10]); \
+ bits4 = _mm_aesenclast_si128(bits4, bits4keys[i]); \
bits5 = _mm_aesenclast_si128(bits5, k); \
} while (0)
@@ -313,11 +313,13 @@
ctr = _mm_insert_epi32(ctr, 1, 0);
ek0 = _mm_shuffle_epi8(ctr, bswap8);
- /* prepare the first bit stream */
- AESECB6_INIT();
- for (size_t i = 1; i < 10; ++i)
- AESECB6_UPDATE(i);
- AESECB6_FINAL();
+ { /* prepare the first bit stream */
+ size_t i;
+ AESECB6_INIT();
+ for (i = 1; i < ctx->ecb.rounds; ++i)
+ AESECB6_UPDATE(i);
+ AESECB6_FINAL(i);
+ }
/* the main loop */
while (1) {
@@ -416,13 +418,14 @@
}
/* run AES and multiplication in parallel */
- for (size_t i = 2; i <= 7; ++i) {
+ size_t i;
+ for (i = 2; i <= 7; ++i) {
AESECB6_UPDATE(i);
gfmul_onestep(&gstate, _mm_loadu_si128(gdata++), --ghash_precompute);
}
- AESECB6_UPDATE(8);
- AESECB6_UPDATE(9);
- AESECB6_FINAL();
+ for (; i < ctx->ecb.rounds; ++i)
+ AESECB6_UPDATE(i);
+ AESECB6_FINAL(i);
}
Finish:
@@ -450,8 +453,8 @@
}
do {
bits4 = _mm_aesenc_si128(bits4, bits4keys[i++]);
- } while (i != 10);
- bits4 = _mm_aesenclast_si128(bits4, bits4keys[10]);
+ } while (i != ctx->ecb.rounds);
+ bits4 = _mm_aesenclast_si128(bits4, bits4keys[i]);
_mm_storeu_si128((__m128i *)supp->output, bits4);
}
@@ -571,7 +574,7 @@
AESECB6_UPDATE(aesi);
gfmul_onestep(&gstate, _mm_loadu_si128(gdata++), --ghash_precompute);
}
- for (; aesi <= 9; ++aesi)
+ for (; aesi < ctx->ecb.rounds; ++aesi)
AESECB6_UPDATE(aesi);
__m128i k = ctx->ecb.keys[aesi];
bits0 = _mm_aesenclast_si128(bits0, k);
@@ -646,24 +649,51 @@
#undef STATE_GHASH_HAS_MORE
}
-static __m128i expand_key(__m128i key, __m128i t)
+static __m128i expand_key(__m128i key, __m128i temp)
{
- t = _mm_shuffle_epi32(t, _MM_SHUFFLE(3, 3, 3, 3));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
- return _mm_xor_si128(key, t);
+
+ key = _mm_xor_si128(key, temp);
+
+ return key;
}
-void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, const void *key)
+void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, int is_enc, const void *key, size_t key_size)
{
+ assert(is_enc && "decryption is not supported (yet)");
+
size_t i = 0;
+ switch (key_size) {
+ case 16: /* AES128 */
+ ctx->rounds = 10;
+ break;
+ case 32: /* AES256 */
+ ctx->rounds = 14;
+ break;
+ default:
+ assert(!"invalid key size; AES128 / AES256 are supported");
+ break;
+ }
+
ctx->keys[i++] = _mm_loadu_si128((__m128i *)key);
+ if (key_size == 32)
+ ctx->keys[i++] = _mm_loadu_si128((__m128i *)key + 1);
+
#define EXPAND(R) \
do { \
- ctx->keys[i] = expand_key(ctx->keys[i - 1], _mm_aeskeygenassist_si128(ctx->keys[i - 1], R)); \
+ ctx->keys[i] = expand_key(ctx->keys[i - key_size / 16], \
+ _mm_shuffle_epi32(_mm_aeskeygenassist_si128(ctx->keys[i - 1], R), _MM_SHUFFLE(3, 3, 3, 3))); \
+ if (i == ctx->rounds) \
+ goto Done; \
++i; \
+ if (key_size > 24) { \
+ ctx->keys[i] = expand_key(ctx->keys[i - key_size / 16], \
+ _mm_shuffle_epi32(_mm_aeskeygenassist_si128(ctx->keys[i - 1], R), _MM_SHUFFLE(2, 2, 2, 2))); \
+ ++i; \
+ } \
} while (0)
EXPAND(0x1);
EXPAND(0x2);
@@ -676,6 +706,8 @@
EXPAND(0x1b);
EXPAND(0x36);
#undef EXPAND
+Done:
+ assert(i == ctx->rounds);
}
void ptls_fusion_aesecb_dispose(ptls_fusion_aesecb_context_t *ctx)
@@ -683,7 +715,14 @@
ptls_clear_memory(ctx, sizeof(*ctx));
}
-ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t max_size)
+void ptls_fusion_aesecb_encrypt(ptls_fusion_aesecb_context_t *ctx, void *dst, const void *src)
+{
+ __m128i v = _mm_loadu_si128(src);
+ v = aesecb_encrypt(ctx, v);
+ _mm_storeu_si128(dst, v);
+}
+
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t max_size)
{
ptls_fusion_aesgcm_context_t *ctx;
size_t ghash_cnt = (max_size + 15) / 16 + 2; // round-up by block size, add to handle worst split of the size between AAD and
@@ -692,7 +731,7 @@
if ((ctx = malloc(sizeof(*ctx) + sizeof(ctx->ghash[0]) * ghash_cnt)) == NULL)
return NULL;
- ptls_fusion_aesecb_init(&ctx->ecb, key);
+ ptls_fusion_aesecb_init(&ctx->ecb, 1, key, key_size);
ctx->ghash_cnt = ghash_cnt;
ctx->ghash[0].H = aesecb_encrypt(&ctx->ecb, _mm_setzero_si128());
@@ -754,7 +793,7 @@
ctx->super.do_dispose = ctr_dispose;
ctx->super.do_init = ctr_init;
ctx->super.do_transform = ctr_transform;
- ptls_fusion_aesecb_init(&ctx->fusion, key);
+ ptls_fusion_aesecb_init(&ctx->fusion, 1, key, PTLS_AES128_KEY_SIZE);
ctx->is_ready = 0;
return 0;
@@ -830,7 +869,8 @@
ctx->super.do_encrypt = aead_do_encrypt;
ctx->super.do_decrypt = aead_do_decrypt;
- ctx->aesgcm = ptls_fusion_aesgcm_new(key, 1500); /* FIXME use realloc with exponential back-off to support arbitrary size */
+ ctx->aesgcm = ptls_fusion_aesgcm_new(key, PTLS_AES128_KEY_SIZE,
+ 1500); /* FIXME use realloc with exponential back-off to support arbitrary size */
return 0;
}
diff --git a/t/fusion.c b/t/fusion.c
index 34dff0d..f9ebd2d 100644
--- a/t/fusion.c
+++ b/t/fusion.c
@@ -56,10 +56,25 @@
static const uint8_t zero[16384] = {}, one[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
{
+ ptls_fusion_aesecb_context_t ecb;
+ uint8_t encrypted[16];
+
+ ptls_fusion_aesecb_init(&ecb, 1, zero, 16);
+ ptls_fusion_aesecb_encrypt(&ecb, encrypted, "hello world!!!!!");
+ ptls_fusion_aesecb_dispose(&ecb);
+ ok(strcmp(tostr(encrypted, 16), "172afecb50b5f1237814b2f7cb51d0f7") == 0);
+
+ ptls_fusion_aesecb_init(&ecb, 1, zero, 32);
+ ptls_fusion_aesecb_encrypt(&ecb, encrypted, "hello world!!!!!");
+ ptls_fusion_aesecb_dispose(&ecb);
+ ok(strcmp(tostr(encrypted, 16), "2a033f0627b3554aa4fe5786550736ff") == 0);
+ }
+
+ {
static const uint8_t expected[] = {0x03, 0x88, 0xda, 0xce, 0x60, 0xb6, 0xa3, 0x92, 0xf3, 0x28, 0xc2,
0xb9, 0x71, 0xb2, 0xfe, 0x78, 0x97, 0x3f, 0xbc, 0xa6, 0x54, 0x77,
0xbf, 0x47, 0x85, 0xb0, 0xd5, 0x61, 0xf7, 0xe3, 0xfd, 0x6c};
- ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, 5 + 16);
+ ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, PTLS_AES128_KEY_SIZE, 5 + 16);
uint8_t encrypted[sizeof(expected)], decrypted[sizeof(expected) - 16];
ptls_fusion_aesgcm_encrypt(ctx, encrypted, zero, 16, _mm_setzero_si128(), "hello", 5, NULL);
@@ -75,7 +90,7 @@
{ /* test capacity */
static const uint8_t expected[17] = {0x5b, 0x27, 0x21, 0x5e, 0xd8, 0x1a, 0x70, 0x2e, 0x39,
0x41, 0xc8, 0x05, 0x77, 0xd5, 0x2f, 0xcb, 0x57};
- ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, 2);
+ ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, PTLS_AES128_KEY_SIZE, 2);
uint8_t encrypted[17], decrypted[1] = {0x55};
ptls_fusion_aesgcm_encrypt(ctx, encrypted, "X", 1, _mm_setzero_si128(), "a", 1, NULL);
ok(memcmp(expected, encrypted, 17) == 0);
@@ -85,7 +100,7 @@
}
{
- ptls_fusion_aesgcm_context_t *aead = ptls_fusion_aesgcm_new(zero, sizeof(zero));
+ ptls_fusion_aesgcm_context_t *aead = ptls_fusion_aesgcm_new(zero, PTLS_AES128_KEY_SIZE, sizeof(zero));
ptls_aead_supplementary_encryption_t *supp = NULL;
for (int i = 0; i < 2; ++i) {
diff --git a/t/fusionbench.c b/t/fusionbench.c
index b7a7d3b..0599fef 100644
--- a/t/fusionbench.c
+++ b/t/fusionbench.c
@@ -53,7 +53,7 @@
if (supp != NULL)
supp->input = textlen >= 2 ? text + 2 : text + textlen;
- ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(key, sizeof(aad) + textlen);
+ ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(key, sizeof(key), sizeof(aad) + textlen);
if (!decrypt) {
for (int i = 0; i < count; ++i)