auto-expand
diff --git a/include/picotls/fusion.h b/include/picotls/fusion.h
index c11de99..332dd93 100644
--- a/include/picotls/fusion.h
+++ b/include/picotls/fusion.h
@@ -47,9 +47,13 @@
/**
* Creates an AES-GCM context.
* @param key the AES key (128 bits)
- * @param max_size maximum size of the record (i.e. AAD + encrypted payload)
+ * @param capacity maximum size of AEAD record (i.e. AAD + encrypted payload)
*/
-ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t max_size);
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t capacity);
+/**
+ * Updates the capacity.
+ */
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_set_capacity(ptls_fusion_aesgcm_context_t *ctx, size_t capacity);
/**
* Destroys an AES-GCM context.
*/
diff --git a/lib/fusion.c b/lib/fusion.c
index adff45e..1bd40b9 100644
--- a/lib/fusion.c
+++ b/lib/fusion.c
@@ -49,6 +49,7 @@
struct ptls_fusion_aesgcm_context {
ptls_fusion_aesecb_context_t ecb;
+ size_t capacity;
size_t ghash_cnt;
struct ptls_fusion_aesgcm_ghash_precompute {
__m128i H;
@@ -722,29 +723,62 @@
_mm_storeu_si128(dst, v);
}
-ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t max_size)
+/**
+ * returns the number of ghash entries that is required to handle an AEAD block of given size
+ */
+static size_t aesgcm_calc_ghash_cnt(size_t capacity)
+{
+ // round-up by block size, add to handle worst split of the size between AAD and payload, plus context to hash AC
+ return (capacity + 15) / 16 + 2;
+}
+
+static void setup_one_ghash_entry(ptls_fusion_aesgcm_context_t *ctx)
+{
+ if (ctx->ghash_cnt != 0)
+ ctx->ghash[ctx->ghash_cnt].H = gfmul(ctx->ghash[ctx->ghash_cnt - 1].H, ctx->ghash[0].H);
+
+ __m128i r = _mm_shuffle_epi32(ctx->ghash[ctx->ghash_cnt].H, 78);
+ r = _mm_xor_si128(r, ctx->ghash[ctx->ghash_cnt].H);
+ ctx->ghash[ctx->ghash_cnt].r = r;
+
+ ++ctx->ghash_cnt;
+}
+
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t capacity)
{
ptls_fusion_aesgcm_context_t *ctx;
- size_t ghash_cnt = (max_size + 15) / 16 + 2; // round-up by block size, add to handle worst split of the size between AAD and
- // payload, plus context to hash AC
+ size_t ghash_cnt = aesgcm_calc_ghash_cnt(capacity);
if ((ctx = malloc(sizeof(*ctx) + sizeof(ctx->ghash[0]) * ghash_cnt)) == NULL)
return NULL;
ptls_fusion_aesecb_init(&ctx->ecb, 1, key, key_size);
- ctx->ghash_cnt = ghash_cnt;
+ ctx->capacity = capacity;
+
ctx->ghash[0].H = aesecb_encrypt(&ctx->ecb, _mm_setzero_si128());
ctx->ghash[0].H = _mm_shuffle_epi8(ctx->ghash[0].H, bswap8);
-
ctx->ghash[0].H = transformH(ctx->ghash[0].H);
- for (int i = 1; i < ghash_cnt; ++i)
- ctx->ghash[i].H = gfmul(ctx->ghash[i - 1].H, ctx->ghash[0].H);
- for (int i = 0; i < ghash_cnt; ++i) {
- __m128i r = _mm_shuffle_epi32(ctx->ghash[i].H, 78);
- r = _mm_xor_si128(r, ctx->ghash[i].H);
- ctx->ghash[i].r = r;
- }
+ ctx->ghash_cnt = 0;
+ while (ctx->ghash_cnt < ghash_cnt)
+ setup_one_ghash_entry(ctx);
+
+ return ctx;
+}
+
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_set_capacity(ptls_fusion_aesgcm_context_t *ctx, size_t capacity)
+{
+ size_t ghash_cnt = aesgcm_calc_ghash_cnt(capacity);
+
+ if (ghash_cnt <= ctx->ghash_cnt)
+ return ctx;
+
+ if ((ctx = realloc(ctx, sizeof(*ctx) + sizeof(ctx->ghash[0]) * ghash_cnt)) == NULL)
+ return NULL;
+
+ ctx->capacity = capacity;
+ while (ghash_cnt < ctx->ghash_cnt)
+ setup_one_ghash_entry(ctx);
return ctx;
}
@@ -847,6 +881,8 @@
{
struct aesgcm_context *ctx = (void *)_ctx;
+ if (inlen + aadlen > ctx->aesgcm->capacity)
+ ctx->aesgcm = ptls_fusion_aesgcm_set_capacity(ctx->aesgcm, inlen + aadlen);
ptls_fusion_aesgcm_encrypt(ctx->aesgcm, output, input, inlen, calc_counter(ctx, seq), aad, aadlen, supp);
}
@@ -859,6 +895,8 @@
return SIZE_MAX;
size_t enclen = inlen - 16;
+ if (enclen + aadlen > ctx->aesgcm->capacity)
+ ctx->aesgcm = ptls_fusion_aesgcm_set_capacity(ctx->aesgcm, enclen + aadlen);
if (!ptls_fusion_aesgcm_decrypt(ctx->aesgcm, output, input, enclen, calc_counter(ctx, seq), aad, aadlen,
(const uint8_t *)input + enclen))
return SIZE_MAX;