Add AES-256 support to the fusion AES-ECB and AES-GCM primitives
diff --git a/include/picotls/fusion.h b/include/picotls/fusion.h
index 786df27..85cdc0d 100644
--- a/include/picotls/fusion.h
+++ b/include/picotls/fusion.h
@@ -30,23 +30,26 @@
 #include <emmintrin.h>
 #include "../picotls.h"
 
-#define PTLS_FUSION_AES_ROUNDS 10 /* TODO support AES256 */
+#define PTLS_FUSION_AES128_ROUNDS 10
+#define PTLS_FUSION_AES256_ROUNDS 14
 
 typedef struct ptls_fusion_aesecb_context {
-    __m128i keys[PTLS_FUSION_AES_ROUNDS + 1];
+    __m128i keys[PTLS_FUSION_AES256_ROUNDS + 1];
+    unsigned rounds;
 } ptls_fusion_aesecb_context_t;
 
 typedef struct ptls_fusion_aesgcm_context ptls_fusion_aesgcm_context_t;
 
-void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, const void *key);
+void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, int is_enc, const void *key, size_t key_size);
 void ptls_fusion_aesecb_dispose(ptls_fusion_aesecb_context_t *ctx);
+void ptls_fusion_aesecb_encrypt(ptls_fusion_aesecb_context_t *ctx, void *dst, const void *src);
 
 /**
  * Creates an AES-GCM context.
  * @param key       the AES key (128 bits)
  * @param max_size  maximum size of the record (i.e. AAD + encrypted payload)
  */
-ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t max_size);
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t max_size);
 /**
  * Destroys an AES-GCM context.
  */
diff --git a/lib/fusion.c b/lib/fusion.c
index 70b229a..8be2bcb 100644
--- a/lib/fusion.c
+++ b/lib/fusion.c
@@ -187,7 +187,7 @@
     size_t i;
 
     v = _mm_xor_si128(v, ctx->keys[0]);
-    for (i = 1; i < PTLS_FUSION_AES_ROUNDS; ++i)
+    for (i = 1; i < ctx->rounds; ++i)
         v = _mm_aesenc_si128(v, ctx->keys[i]);
     v = _mm_aesenclast_si128(v, ctx->keys[i]);
 
@@ -272,14 +272,14 @@
     } while (0)
 
 /* aesenclast */
-#define AESECB6_FINAL()                                                                                                            \
+#define AESECB6_FINAL(i)                                                                                                           \
     do {                                                                                                                           \
-        __m128i k = ctx->ecb.keys[10];                                                                                             \
+        __m128i k = ctx->ecb.keys[i];                                                                                              \
         bits0 = _mm_aesenclast_si128(bits0, k);                                                                                    \
         bits1 = _mm_aesenclast_si128(bits1, k);                                                                                    \
         bits2 = _mm_aesenclast_si128(bits2, k);                                                                                    \
         bits3 = _mm_aesenclast_si128(bits3, k);                                                                                    \
-        bits4 = _mm_aesenclast_si128(bits4, bits4keys[10]);                                                                        \
+        bits4 = _mm_aesenclast_si128(bits4, bits4keys[i]);                                                                         \
         bits5 = _mm_aesenclast_si128(bits5, k);                                                                                    \
     } while (0)
 
@@ -313,11 +313,13 @@
     ctr = _mm_insert_epi32(ctr, 1, 0);
     ek0 = _mm_shuffle_epi8(ctr, bswap8);
 
-    /* prepare the first bit stream */
-    AESECB6_INIT();
-    for (size_t i = 1; i < 10; ++i)
-        AESECB6_UPDATE(i);
-    AESECB6_FINAL();
+    { /* prepare the first bit stream */
+        size_t i;
+        AESECB6_INIT();
+        for (i = 1; i < ctx->ecb.rounds; ++i)
+            AESECB6_UPDATE(i);
+        AESECB6_FINAL(i);
+    }
 
     /* the main loop */
     while (1) {
@@ -416,13 +418,14 @@
         }
 
         /* run AES and multiplication in parallel */
-        for (size_t i = 2; i <= 7; ++i) {
+        size_t i;
+        for (i = 2; i <= 7; ++i) {
             AESECB6_UPDATE(i);
             gfmul_onestep(&gstate, _mm_loadu_si128(gdata++), --ghash_precompute);
         }
-        AESECB6_UPDATE(8);
-        AESECB6_UPDATE(9);
-        AESECB6_FINAL();
+        for (; i < ctx->ecb.rounds; ++i)
+            AESECB6_UPDATE(i);
+        AESECB6_FINAL(i);
     }
 
 Finish:
@@ -450,8 +453,8 @@
         }
         do {
             bits4 = _mm_aesenc_si128(bits4, bits4keys[i++]);
-        } while (i != 10);
-        bits4 = _mm_aesenclast_si128(bits4, bits4keys[10]);
+        } while (i != ctx->ecb.rounds);
+        bits4 = _mm_aesenclast_si128(bits4, bits4keys[i]);
         _mm_storeu_si128((__m128i *)supp->output, bits4);
     }
 
@@ -571,7 +574,7 @@
                 AESECB6_UPDATE(aesi);
                 gfmul_onestep(&gstate, _mm_loadu_si128(gdata++), --ghash_precompute);
             }
-            for (; aesi <= 9; ++aesi)
+            for (; aesi < ctx->ecb.rounds; ++aesi)
                 AESECB6_UPDATE(aesi);
             __m128i k = ctx->ecb.keys[aesi];
             bits0 = _mm_aesenclast_si128(bits0, k);
@@ -646,24 +649,51 @@
 #undef STATE_GHASH_HAS_MORE
 }
 
-static __m128i expand_key(__m128i key, __m128i t)
+static __m128i expand_key(__m128i key, __m128i temp)
 {
-    t = _mm_shuffle_epi32(t, _MM_SHUFFLE(3, 3, 3, 3));
     key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
     key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
     key = _mm_xor_si128(key, _mm_slli_si128(key, 4));
-    return _mm_xor_si128(key, t);
+
+    key = _mm_xor_si128(key, temp);
+
+    return key;
 }
 
-void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, const void *key)
+void ptls_fusion_aesecb_init(ptls_fusion_aesecb_context_t *ctx, int is_enc, const void *key, size_t key_size)
 {
+    assert(is_enc && "decryption is not supported (yet)");
+
     size_t i = 0;
 
+    switch (key_size) {
+    case 16: /* AES128 */
+        ctx->rounds = 10;
+        break;
+    case 32: /* AES256 */
+        ctx->rounds = 14;
+        break;
+    default:
+        assert(!"invalid key size; AES128 / AES256 are supported");
+        break;
+    }
+
     ctx->keys[i++] = _mm_loadu_si128((__m128i *)key);
+    if (key_size == 32)
+        ctx->keys[i++] = _mm_loadu_si128((__m128i *)key + 1);
+
 #define EXPAND(R)                                                                                                                  \
     do {                                                                                                                           \
-        ctx->keys[i] = expand_key(ctx->keys[i - 1], _mm_aeskeygenassist_si128(ctx->keys[i - 1], R));                               \
+        ctx->keys[i] = expand_key(ctx->keys[i - key_size / 16],                                                                    \
+                                  _mm_shuffle_epi32(_mm_aeskeygenassist_si128(ctx->keys[i - 1], R), _MM_SHUFFLE(3, 3, 3, 3)));     \
+        if (i == ctx->rounds)                                                                                                      \
+            goto Done;                                                                                                             \
         ++i;                                                                                                                       \
+        if (key_size > 24) {                                                                                                       \
+            ctx->keys[i] = expand_key(ctx->keys[i - key_size / 16],                                                                \
+                                      _mm_shuffle_epi32(_mm_aeskeygenassist_si128(ctx->keys[i - 1], R), _MM_SHUFFLE(2, 2, 2, 2))); \
+            ++i;                                                                                                                   \
+        }                                                                                                                          \
     } while (0)
     EXPAND(0x1);
     EXPAND(0x2);
@@ -676,6 +706,8 @@
     EXPAND(0x1b);
     EXPAND(0x36);
 #undef EXPAND
+Done:
+    assert(i == ctx->rounds);
 }
 
 void ptls_fusion_aesecb_dispose(ptls_fusion_aesecb_context_t *ctx)
@@ -683,7 +715,14 @@
     ptls_clear_memory(ctx, sizeof(*ctx));
 }
 
-ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t max_size)
+void ptls_fusion_aesecb_encrypt(ptls_fusion_aesecb_context_t *ctx, void *dst, const void *src)
+{
+    __m128i v = _mm_loadu_si128(src);
+    v = aesecb_encrypt(ctx, v);
+    _mm_storeu_si128(dst, v);
+}
+
+ptls_fusion_aesgcm_context_t *ptls_fusion_aesgcm_new(const void *key, size_t key_size, size_t max_size)
 {
     ptls_fusion_aesgcm_context_t *ctx;
     size_t ghash_cnt = (max_size + 15) / 16 + 2; // round-up by block size, add to handle worst split of the size between AAD and
@@ -692,7 +731,7 @@
     if ((ctx = malloc(sizeof(*ctx) + sizeof(ctx->ghash[0]) * ghash_cnt)) == NULL)
         return NULL;
 
-    ptls_fusion_aesecb_init(&ctx->ecb, key);
+    ptls_fusion_aesecb_init(&ctx->ecb, 1, key, key_size);
 
     ctx->ghash_cnt = ghash_cnt;
     ctx->ghash[0].H = aesecb_encrypt(&ctx->ecb, _mm_setzero_si128());
@@ -754,7 +793,7 @@
     ctx->super.do_dispose = ctr_dispose;
     ctx->super.do_init = ctr_init;
     ctx->super.do_transform = ctr_transform;
-    ptls_fusion_aesecb_init(&ctx->fusion, key);
+    ptls_fusion_aesecb_init(&ctx->fusion, 1, key, PTLS_AES128_KEY_SIZE);
     ctx->is_ready = 0;
 
     return 0;
@@ -830,7 +869,8 @@
     ctx->super.do_encrypt = aead_do_encrypt;
     ctx->super.do_decrypt = aead_do_decrypt;
 
-    ctx->aesgcm = ptls_fusion_aesgcm_new(key, 1500); /* FIXME use realloc with exponential back-off to support arbitrary size */
+    ctx->aesgcm = ptls_fusion_aesgcm_new(key, PTLS_AES128_KEY_SIZE,
+                                         1500); /* FIXME use realloc with exponential back-off to support arbitrary size */
 
     return 0;
 }
diff --git a/t/fusion.c b/t/fusion.c
index 34dff0d..f9ebd2d 100644
--- a/t/fusion.c
+++ b/t/fusion.c
@@ -56,10 +56,25 @@
     static const uint8_t zero[16384] = {}, one[16] = {1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
 
     {
+        ptls_fusion_aesecb_context_t ecb;
+        uint8_t encrypted[16];
+
+        ptls_fusion_aesecb_init(&ecb, 1, zero, 16);
+        ptls_fusion_aesecb_encrypt(&ecb, encrypted, "hello world!!!!!");
+        ptls_fusion_aesecb_dispose(&ecb);
+        ok(strcmp(tostr(encrypted, 16), "172afecb50b5f1237814b2f7cb51d0f7") == 0);
+
+        ptls_fusion_aesecb_init(&ecb, 1, zero, 32);
+        ptls_fusion_aesecb_encrypt(&ecb, encrypted, "hello world!!!!!");
+        ptls_fusion_aesecb_dispose(&ecb);
+        ok(strcmp(tostr(encrypted, 16), "2a033f0627b3554aa4fe5786550736ff") == 0);
+    }
+
+    {
         static const uint8_t expected[] = {0x03, 0x88, 0xda, 0xce, 0x60, 0xb6, 0xa3, 0x92, 0xf3, 0x28, 0xc2,
                                            0xb9, 0x71, 0xb2, 0xfe, 0x78, 0x97, 0x3f, 0xbc, 0xa6, 0x54, 0x77,
                                            0xbf, 0x47, 0x85, 0xb0, 0xd5, 0x61, 0xf7, 0xe3, 0xfd, 0x6c};
-        ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, 5 + 16);
+        ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, PTLS_AES128_KEY_SIZE, 5 + 16);
         uint8_t encrypted[sizeof(expected)], decrypted[sizeof(expected) - 16];
 
         ptls_fusion_aesgcm_encrypt(ctx, encrypted, zero, 16, _mm_setzero_si128(), "hello", 5, NULL);
@@ -75,7 +90,7 @@
     { /* test capacity */
         static const uint8_t expected[17] = {0x5b, 0x27, 0x21, 0x5e, 0xd8, 0x1a, 0x70, 0x2e, 0x39,
                                              0x41, 0xc8, 0x05, 0x77, 0xd5, 0x2f, 0xcb, 0x57};
-        ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, 2);
+        ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(zero, PTLS_AES128_KEY_SIZE, 2);
         uint8_t encrypted[17], decrypted[1] = {0x55};
         ptls_fusion_aesgcm_encrypt(ctx, encrypted, "X", 1, _mm_setzero_si128(), "a", 1, NULL);
         ok(memcmp(expected, encrypted, 17) == 0);
@@ -85,7 +100,7 @@
     }
 
     {
-        ptls_fusion_aesgcm_context_t *aead = ptls_fusion_aesgcm_new(zero, sizeof(zero));
+        ptls_fusion_aesgcm_context_t *aead = ptls_fusion_aesgcm_new(zero, PTLS_AES128_KEY_SIZE, sizeof(zero));
         ptls_aead_supplementary_encryption_t *supp = NULL;
 
         for (int i = 0; i < 2; ++i) {
diff --git a/t/fusionbench.c b/t/fusionbench.c
index b7a7d3b..0599fef 100644
--- a/t/fusionbench.c
+++ b/t/fusionbench.c
@@ -53,7 +53,7 @@
     if (supp != NULL)
         supp->input = textlen >= 2 ? text + 2 : text + textlen;
 
-    ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(key, sizeof(aad) + textlen);
+    ptls_fusion_aesgcm_context_t *ctx = ptls_fusion_aesgcm_new(key, sizeof(key), sizeof(aad) + textlen);
 
     if (!decrypt) {
         for (int i = 0; i < count; ++i)