Merge pull request #331 from h2o/iv-96

Add support for xor_iv
diff --git a/include/picotls.h b/include/picotls.h
index 0ffe7d7..90a3e53 100644
--- a/include/picotls.h
+++ b/include/picotls.h
@@ -324,6 +324,7 @@
     const struct st_ptls_aead_algorithm_t *algo;
     /* field above this line must not be altered by the crypto binding */
     void (*dispose_crypto)(struct st_ptls_aead_context_t *ctx);
+    void (*do_xor_iv)(struct st_ptls_aead_context_t *ctx, const void *bytes, size_t len); /* XORs bytes into the static IV */
     void (*do_encrypt_init)(struct st_ptls_aead_context_t *ctx, uint64_t seq, const void *aad, size_t aadlen);
     size_t (*do_encrypt_update)(struct st_ptls_aead_context_t *ctx, void *output, const void *input, size_t inlen);
     size_t (*do_encrypt_final)(struct st_ptls_aead_context_t *ctx, void *output);
@@ -1229,6 +1230,11 @@
  */
 void ptls_aead_free(ptls_aead_context_t *ctx);
 /**
+ * Permutes the static IV by applying the given bytes using bit-wise XOR. This API can be used for supplying nonces longer than
+ * 64 bits.
+ */
+static void ptls_aead_xor_iv(ptls_aead_context_t *ctx, const void *bytes, size_t len); 
+/**
  *
  */
 static size_t ptls_aead_encrypt(ptls_aead_context_t *ctx, void *output, const void *input, size_t inlen, uint64_t seq,
@@ -1413,6 +1419,11 @@
     ctx->do_transform(ctx, output, input, len);
 }
 
+inline void ptls_aead_xor_iv(ptls_aead_context_t *ctx, const void *bytes, size_t len)
+{
+    ctx->do_xor_iv(ctx, bytes, len); /* forward to the do_xor_iv callback stored in the context */
+}
+
 inline size_t ptls_aead_encrypt(ptls_aead_context_t *ctx, void *output, const void *input, size_t inlen, uint64_t seq,
                                 const void *aad, size_t aadlen)
 {
diff --git a/lib/cifra/aes-common.h b/lib/cifra/aes-common.h
index 0c393c5..dfb620a 100644
--- a/lib/cifra/aes-common.h
+++ b/lib/cifra/aes-common.h
@@ -154,11 +154,21 @@
     return tag_offset;
 }
 
+static inline void aesgcm_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+{
+    /* XOR the `len` bytes at `_bytes` into the same offsets of the context's static IV */
+    struct aesgcm_context_t *self = (struct aesgcm_context_t *)_ctx;
+    const uint8_t *mask = _bytes;
+    for (size_t pos = 0; pos != len; ++pos)
+        self->static_iv[pos] ^= mask[pos];
+}
+
 static inline int aead_aesgcm_setup_crypto(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv)
 {
     struct aesgcm_context_t *ctx = (struct aesgcm_context_t *)_ctx;
 
     ctx->super.dispose_crypto = aesgcm_dispose_crypto;
+    ctx->super.do_xor_iv = aesgcm_xor_iv;
     if (is_enc) {
         ctx->super.do_encrypt_init = aesgcm_encrypt_init;
         ctx->super.do_encrypt_update = aesgcm_encrypt_update;
diff --git a/lib/cifra/chacha20.c b/lib/cifra/chacha20.c
index 57a6c39..fe19c32 100644
--- a/lib/cifra/chacha20.c
+++ b/lib/cifra/chacha20.c
@@ -179,11 +179,21 @@
     return ret;
 }
 
+static void chacha20poly1305_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+{
+    /* XOR the `len` bytes at `_bytes` into the same offsets of the context's static IV */
+    struct chacha20poly1305_context_t *self = (struct chacha20poly1305_context_t *)_ctx;
+    const uint8_t *mask = _bytes;
+    for (size_t pos = 0; pos != len; ++pos)
+        self->static_iv[pos] ^= mask[pos];
+}
+
 static int aead_chacha20poly1305_setup_crypto(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv)
 {
     struct chacha20poly1305_context_t *ctx = (struct chacha20poly1305_context_t *)_ctx;
 
     ctx->super.dispose_crypto = chacha20poly1305_dispose_crypto;
+    ctx->super.do_xor_iv = chacha20poly1305_xor_iv;
     if (is_enc) {
         ctx->super.do_encrypt_init = chacha20poly1305_init;
         ctx->super.do_encrypt_update = chacha20poly1305_encrypt_update;
diff --git a/lib/fusion.c b/lib/fusion.c
index 1610447..d6562b4 100644
--- a/lib/fusion.c
+++ b/lib/fusion.c
@@ -924,6 +924,14 @@
     return enclen;
 }
 
+static inline void aesgcm_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+{
+    struct aesgcm_context *ctx = (struct aesgcm_context *)_ctx;
+    __m128i xor_mask = loadn(_bytes, len); /* presumably loads the first `len` bytes into a vector (loadn is defined elsewhere) */
+    xor_mask = _mm_shuffle_epi8(xor_mask, bswap8); /* shuffle bytes with bswap8; presumably to match static_iv's layout - confirm */
+    ctx->static_iv = _mm_xor_si128(ctx->static_iv, xor_mask); /* bit-wise XOR of the 128-bit static IV with the mask */
+}
+
 static int aesgcm_setup(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv, size_t key_size)
 {
     struct aesgcm_context *ctx = (struct aesgcm_context *)_ctx;
@@ -934,6 +942,7 @@
         return 0;
 
     ctx->super.dispose_crypto = aesgcm_dispose_crypto;
+    ctx->super.do_xor_iv = aesgcm_xor_iv;
     ctx->super.do_encrypt_init = aead_do_encrypt_init;
     ctx->super.do_encrypt_update = aead_do_encrypt_update;
     ctx->super.do_encrypt_final = aead_do_encrypt_final;
diff --git a/lib/openssl.c b/lib/openssl.c
index c7ba115..00c7777 100644
--- a/lib/openssl.c
+++ b/lib/openssl.c
@@ -779,6 +779,15 @@
         EVP_CIPHER_CTX_free(ctx->evp_ctx);
 }
 
+static void aead_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+{
+    /* XOR the `len` bytes at `_bytes` into the same offsets of the context's static IV */
+    struct aead_crypto_context_t *self = (struct aead_crypto_context_t *)_ctx;
+    const uint8_t *mask = _bytes;
+    for (size_t pos = 0; pos != len; ++pos)
+        self->static_iv[pos] ^= mask[pos];
+}
+
 static void aead_do_encrypt_init(ptls_aead_context_t *_ctx, uint64_t seq, const void *aad, size_t aadlen)
 {
     struct aead_crypto_context_t *ctx = (struct aead_crypto_context_t *)_ctx;
@@ -864,6 +873,7 @@
         return 0;
 
     ctx->super.dispose_crypto = aead_dispose_crypto;
+    ctx->super.do_xor_iv = aead_xor_iv;
     if (is_enc) {
         ctx->super.do_encrypt_init = aead_do_encrypt_init;
         ctx->super.do_encrypt_update = aead_do_encrypt_update;
diff --git a/lib/ptlsbcrypt.c b/lib/ptlsbcrypt.c
index 9f46a47..43c442f 100644
--- a/lib/ptlsbcrypt.c
+++ b/lib/ptlsbcrypt.c
@@ -476,6 +476,15 @@
     }
 }
 
+static void ptls_bcrypt_aead_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+{
+    struct ptls_bcrypt_aead_context_t *ctx = (struct ptls_bcrypt_aead_context_t *)_ctx;
+    const uint8_t *bytes = _bytes;
+    /* the static IV is kept in iv_static (filled by ptls_bcrypt_aead_setup_crypto), so that is the buffer to permute */
+    for (size_t i = 0; i < len; ++i)
+        ctx->bctx.iv_static[i] ^= bytes[i];
+}
+
 static int ptls_bcrypt_aead_setup_crypto(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv,
                                          wchar_t const *bcrypt_name, wchar_t const *bcrypt_mode, size_t bcrypt_mode_size)
 {
@@ -530,6 +539,7 @@
         memcpy(ctx->bctx.iv_static, iv, ctx->super.algo->iv_size);
         if (is_enc) {
             ctx->super.dispose_crypto = ptls_bcrypt_aead_dispose_crypto;
+            ctx->super.do_xor_iv = ptls_bcrypt_aead_xor_iv;
             ctx->super.do_decrypt = NULL;
             ctx->super.do_encrypt_init = ptls_bcrypt_aead_do_encrypt_init;
             ctx->super.do_encrypt_update = ptls_bcrypt_aead_do_encrypt_update;
diff --git a/picotlsvs/picotlsvs.sln b/picotlsvs/picotlsvs.sln
index 4f79880..6be9680 100644
--- a/picotlsvs/picotlsvs.sln
+++ b/picotlsvs/picotlsvs.sln
@@ -155,13 +155,13 @@
 		{7A04A2BF-03BF-4C3A-9E44-A53B0E90036B}.Release|x86.Build.0 = Release|Win32
 		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Debug|x64.ActiveCfg = Debug|x64
 		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Debug|x64.Build.0 = Debug|x64
-		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Debug|x86.ActiveCfg = Debug|Win32
+		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Debug|x86.ActiveCfg = Debug|x64
 		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Release|x64.ActiveCfg = Release|x64
 		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Release|x64.Build.0 = Release|x64
 		{55F22DE6-EAAE-4279-97B7-84FAAB7F29BB}.Release|x86.ActiveCfg = Release|Win32
 		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Debug|x64.ActiveCfg = Debug|x64
 		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Debug|x64.Build.0 = Debug|x64
-		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Debug|x86.ActiveCfg = Debug|Win32
+		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Debug|x86.ActiveCfg = Debug|x64
 		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Release|x64.ActiveCfg = Release|x64
 		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Release|x64.Build.0 = Release|x64
 		{513DC1C9-2CD4-437C-B8EC-168924EB5AAC}.Release|x86.ActiveCfg = Release|Win32
diff --git a/t/fusion.c b/t/fusion.c
index a7a9685..f9c61bb 100644
--- a/t/fusion.c
+++ b/t/fusion.c
@@ -19,6 +19,7 @@
  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
  * IN THE SOFTWARE.
  */
+
 #include <assert.h>
 #include <stdio.h>
 #include <string.h>
@@ -50,7 +51,7 @@
 
 static void test_loadn(void)
 {
-    uint8_t buf[8192] = { 0 };
+    uint8_t buf[8192] = {0};
 
     for (size_t off = 0; off < 8192 - 15; ++off) {
         uint8_t *src = buf + off;
@@ -65,7 +66,7 @@
     ok(!!"success");
 }
 
-static const uint8_t zero[16384] = { 0 };
+static const uint8_t zero[16384] = {0};
 
 static void test_ecb(void)
 {
@@ -193,21 +194,62 @@
     ptls_fusion_aesgcm_free(aead);
 }
 
-static void test_generated(int aes256)
+static void gcm_iv96(void)
+{
+    static const uint8_t key[16] = {0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff},
+                         aad[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19},
+                         iv[] = {20, 20, 20, 20, 24, 25, 26, 27, 28, 29, 30, 31},
+                         plaintext[] =
+                             "hello world\nhello world\nhello world\nhello world\nhello world\nhello world\nhello world\n";
+    static const uint8_t expected[] = {0xd3, 0xa8, 0x1d, 0x96, 0x4c, 0x9b, 0x02, 0xd7, 0x9a, 0xb0, 0x41, 0x07, 0x4c, 0x8c, 0xe2,
+                                       0xe0, 0x2e, 0x83, 0x54, 0x52, 0x45, 0xcb, 0xd4, 0x68, 0xc8, 0x43, 0x45, 0xca, 0x91, 0xfb,
+                                       0xa3, 0x7a, 0x67, 0xed, 0xe8, 0xd7, 0x5e, 0xe2, 0x33, 0xd1, 0x3e, 0xbf, 0x50, 0xc2, 0x4b,
+                                       0x86, 0x83, 0x55, 0x11, 0xbb, 0x17, 0x4f, 0xf5, 0x78, 0xb8, 0x65, 0xeb, 0x9a, 0x2b, 0x8f,
+                                       0x77, 0x08, 0xa9, 0x60, 0x17, 0x73, 0xc5, 0x07, 0xf3, 0x04, 0xc9, 0x3f, 0x67, 0x4d, 0x12,
+                                       0xa1, 0x02, 0x93, 0xc2, 0x3c, 0xd3, 0xf8, 0x59, 0x33, 0xd5, 0x01, 0xc3, 0xbb, 0xaa, 0xe6,
+                                       0x3f, 0xbb, 0x23, 0x66, 0x94, 0x26, 0x28, 0x43, 0xa5, 0xfd, 0x2f};
+
+    ptls_aead_context_t *aead = ptls_aead_new_direct(&ptls_fusion_aes128gcm, 0, key, iv);
+    uint8_t encrypted[sizeof(plaintext) + 16], decrypted[sizeof(plaintext)];
+    uint8_t seq32[4] = {0, 1, 2, 3};
+    uint8_t seq32_bad[4] = {0x89, 0xab, 0xcd, 0xef};
+
+    ptls_aead_xor_iv(aead, seq32, sizeof(seq32)); /* mix seq32 into the IV before checking the output against `expected` */
+    ptls_aead_encrypt(aead, encrypted, plaintext, sizeof(plaintext), 0, aad, sizeof(aad));
+    ok(memcmp(expected, encrypted, sizeof(plaintext)) == 0); /* the first sizeof(plaintext) output bytes */
+    ok(memcmp(expected + sizeof(plaintext), encrypted + sizeof(plaintext), 16) == 0); /* the trailing 16 output bytes */
+    ok(ptls_aead_decrypt(aead, decrypted, encrypted, sizeof(encrypted), 0, aad, sizeof(aad)) == sizeof(plaintext));
+    ok(memcmp(decrypted, plaintext, sizeof(plaintext)) == 0);
+    ptls_aead_xor_iv(aead, seq32, sizeof(seq32)); /* XOR with the same bytes again to undo seq32 */
+    ptls_aead_xor_iv(aead, seq32_bad, sizeof(seq32_bad)); /* apply a different mask; the decryption below is expected to fail */
+    ok(ptls_aead_decrypt(aead, decrypted, encrypted, sizeof(encrypted), 0, aad, sizeof(aad)) == SIZE_MAX);
+    ptls_aead_xor_iv(aead, seq32_bad, sizeof(seq32_bad)); /* undo the bad mask */
+    ptls_aead_xor_iv(aead, seq32, sizeof(seq32)); /* re-apply seq32; decryption is expected to succeed again */
+    ok(ptls_aead_decrypt(aead, decrypted, encrypted, sizeof(encrypted), 0, aad, sizeof(aad)) == sizeof(plaintext));
+    ok(memcmp(decrypted, plaintext, sizeof(plaintext)) == 0);
+    ptls_aead_free(aead);
+}
+
+static void test_generated(int aes256, int iv96)
 {
     ptls_cipher_context_t *rand = ptls_cipher_new(&ptls_minicrypto_aes128ctr, 1, zero);
     ptls_cipher_init(rand, zero);
     int i;
-
-    for (i = 0; i < 10000; ++i) {
+#ifdef _WINDOWS
+    const int nb_runs = 1000;
+#else
+    const int nb_runs = 10000;
+#endif
+    for (i = 0; i < nb_runs; ++i) {
         /* generate input using RNG */
-        uint8_t key[32], iv[12], aadlen, textlen;
+        uint8_t key[32], iv[12], seq32[4], aadlen, textlen;
         uint64_t seq;
         ptls_cipher_encrypt(rand, key, zero, sizeof(key));
         ptls_cipher_encrypt(rand, iv, zero, sizeof(iv));
         ptls_cipher_encrypt(rand, &aadlen, zero, sizeof(aadlen));
         ptls_cipher_encrypt(rand, &textlen, zero, sizeof(textlen));
         ptls_cipher_encrypt(rand, &seq, zero, sizeof(seq));
+        ptls_cipher_encrypt(rand, seq32, zero, sizeof(seq32));
 
         uint8_t aad[256], text[256];
 
@@ -222,6 +264,9 @@
         { /* check using fusion */
             ptls_aead_context_t *fusion =
                 ptls_aead_new_direct(aes256 ? &ptls_fusion_aes256gcm : &ptls_fusion_aes128gcm, 1, key, iv);
+            if (iv96) {
+                ptls_aead_xor_iv(fusion, seq32, sizeof(seq32));
+            }
             ptls_aead_encrypt(fusion, encrypted, text, textlen, seq, aad, aadlen);
             if (ptls_aead_decrypt(fusion, decrypted, encrypted, textlen + 16, seq, aad, aadlen) != textlen)
                 goto Fail;
@@ -235,6 +280,9 @@
         { /* check that the encrypted text can be decrypted by OpenSSL */
             ptls_aead_context_t *mc =
                 ptls_aead_new_direct(aes256 ? &ptls_minicrypto_aes256gcm : &ptls_minicrypto_aes128gcm, 0, key, iv);
+            if (iv96) {
+                ptls_aead_xor_iv(mc, seq32, sizeof(seq32));
+            }
             if (ptls_aead_decrypt(mc, decrypted, encrypted, textlen + 16, seq, aad, aadlen) != textlen)
                 goto Fail;
             if (memcmp(decrypted, text, textlen) != 0)
@@ -254,14 +302,23 @@
 
 static void test_generated_aes128(void)
 {
-    test_generated(0);
+    test_generated(0, 0);
 }
 
 static void test_generated_aes256(void)
 {
-    test_generated(1);
+    test_generated(1, 0);
 }
 
+static void test_generated_aes128_iv96(void)
+{
+    test_generated(0, 1); /* aes256 = 0, iv96 = 1 */
+}
+
+static void test_generated_aes256_iv96(void)
+{
+    test_generated(1, 1); /* aes256 = 1, iv96 = 1 */
+}
 int main(int argc, char **argv)
 {
     if (!ptls_fusion_is_supported_by_cpu()) {
@@ -274,8 +331,11 @@
     subtest("gcm-basic", gcm_basic);
     subtest("gcm-capacity", gcm_capacity);
     subtest("gcm-test-vectors", gcm_test_vectors);
+    subtest("gcm-iv96", gcm_iv96);
     subtest("generated-128", test_generated_aes128);
     subtest("generated-256", test_generated_aes256);
+    subtest("generated-128-iv96", test_generated_aes128_iv96);
+    subtest("generated-256-iv96", test_generated_aes256_iv96);
 
     return done_testing();
-}
+}
\ No newline at end of file
diff --git a/t/picotls.c b/t/picotls.c
index ca07a90..74e24eb 100644
--- a/t/picotls.c
+++ b/t/picotls.c
@@ -192,6 +192,46 @@
     ptls_aead_free(c);
 }
 
+static void test_aad96_ciphersuite(ptls_cipher_suite_t *cs1, ptls_cipher_suite_t *cs2)
+{
+    const char *traffic_secret = "012345678901234567890123456789012345678901234567", *src = "hello world", *aad = "my true aad";
+    ptls_aead_context_t *c;
+    char enc[256], dec[256];
+    uint8_t seq32[4] = {0xa1, 0xb2, 0xc3, 0xd4};
+    uint8_t seq32_bad[4] = {0xa2, 0xb3, 0xc4, 0xe5};
+    size_t enclen, declen;
+
+    /* encrypt using cs1 with seq32 XOR-ed into the static IV; cs2 must then decrypt it only under the same mask */
+    c = ptls_aead_new(cs1->aead, cs1->hash, 1, traffic_secret, NULL);
+    assert(c != NULL);
+    ptls_aead_xor_iv(c, seq32, sizeof(seq32));
+    ptls_aead_encrypt_init(c, 123, aad, strlen(aad));
+    enclen = ptls_aead_encrypt_update(c, enc, src, strlen(src));
+    enclen += ptls_aead_encrypt_final(c, enc + enclen);
+    ptls_aead_free(c);
+
+    /* decrypt using cs2, with a context derived from the same traffic secret */
+    c = ptls_aead_new(cs2->aead, cs2->hash, 0, traffic_secret, NULL);
+    assert(c != NULL);
+    /* test first decryption */
+    ptls_aead_xor_iv(c, seq32, sizeof(seq32));
+    declen = ptls_aead_decrypt(c, dec, enc, enclen, 123, aad, strlen(aad));
+    ptls_aead_xor_iv(c, seq32, sizeof(seq32)); /* XOR again with the same bytes to undo the change */
+    ok(declen == strlen(src));
+    ok(memcmp(src, dec, declen) == 0);
+    /* test that setting the wrong IV creates an error */
+    ptls_aead_xor_iv(c, seq32_bad, sizeof(seq32_bad));
+    declen = ptls_aead_decrypt(c, dec, enc, enclen, 123, aad, strlen(aad));
+    ptls_aead_xor_iv(c, seq32_bad, sizeof(seq32_bad)); /* undo the bad mask before the next check */
+    ok(declen == SIZE_MAX);
+    /* test second decryption with correct IV to verify no side effect */
+    ptls_aead_xor_iv(c, seq32, sizeof(seq32));
+    declen = ptls_aead_decrypt(c, dec, enc, enclen, 123, aad, strlen(aad));
+    ok(declen == strlen(src));
+    ok(memcmp(src, dec, declen) == 0);
+    ptls_aead_free(c);
+}
+
 static void test_ecb(ptls_cipher_algorithm_t *algo, const void *expected, size_t expected_len)
 {
     static const uint8_t key[] = {0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12, 13, 14, 15,
@@ -285,31 +325,34 @@
 static void test_aes128gcm(void)
 {
     ptls_cipher_suite_t *cs = find_cipher(ctx, PTLS_CIPHER_SUITE_AES_128_GCM_SHA256),
-                        *cs_peer = find_cipher(ctx, PTLS_CIPHER_SUITE_AES_128_GCM_SHA256);
+                        *cs_peer = find_cipher(ctx_peer, PTLS_CIPHER_SUITE_AES_128_GCM_SHA256);
 
     test_ciphersuite(cs, cs_peer);
     test_aad_ciphersuite(cs, cs_peer);
+    test_aad96_ciphersuite(cs, cs_peer);
 }
 
 static void test_aes256gcm(void)
 {
     ptls_cipher_suite_t *cs = find_cipher(ctx, PTLS_CIPHER_SUITE_AES_256_GCM_SHA384),
-                        *cs_peer = find_cipher(ctx, PTLS_CIPHER_SUITE_AES_256_GCM_SHA384);
+                        *cs_peer = find_cipher(ctx_peer, PTLS_CIPHER_SUITE_AES_256_GCM_SHA384);
 
     if (cs != NULL && cs_peer != NULL) {
         test_ciphersuite(cs, cs_peer);
         test_aad_ciphersuite(cs, cs_peer);
+        test_aad96_ciphersuite(cs, cs_peer);
     }
 }
 
 static void test_chacha20poly1305(void)
 {
     ptls_cipher_suite_t *cs = find_cipher(ctx, PTLS_CIPHER_SUITE_CHACHA20_POLY1305_SHA256),
-                        *cs_peer = find_cipher(ctx, PTLS_CIPHER_SUITE_CHACHA20_POLY1305_SHA256);
+                        *cs_peer = find_cipher(ctx_peer, PTLS_CIPHER_SUITE_CHACHA20_POLY1305_SHA256);
 
     if (cs != NULL && cs_peer != NULL) {
         test_ciphersuite(cs, cs_peer);
         test_aad_ciphersuite(cs, cs_peer);
+        test_aad96_ciphersuite(cs, cs_peer);
     }
 }