Merge pull request #468 from h2o/kazuho/setget-iv

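Provide API for reading and fully rewriting the AEAD static IV (ptls_aead_get_iv / ptls_aead_set_iv), replacing the XOR-only backend callback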
diff --git a/include/picotls.h b/include/picotls.h
index 0a42d52..420861f 100644
--- a/include/picotls.h
+++ b/include/picotls.h
@@ -398,7 +398,8 @@
     const struct st_ptls_aead_algorithm_t *algo;
     /* field above this line must not be altered by the crypto binding */
     void (*dispose_crypto)(struct st_ptls_aead_context_t *ctx);
-    void (*do_xor_iv)(struct st_ptls_aead_context_t *ctx, const void *bytes, size_t len);
+    void (*do_get_iv)(struct st_ptls_aead_context_t *ctx, void *iv);
+    void (*do_set_iv)(struct st_ptls_aead_context_t *ctx, const void *iv);
     void (*do_encrypt_init)(struct st_ptls_aead_context_t *ctx, uint64_t seq, const void *aad, size_t aadlen);
     size_t (*do_encrypt_update)(struct st_ptls_aead_context_t *ctx, void *output, const void *input, size_t inlen);
     size_t (*do_encrypt_final)(struct st_ptls_aead_context_t *ctx, void *output);
@@ -1609,7 +1610,15 @@
  * Permutes the static IV by applying given bytes using bit-wise XOR. This API can be used for supplying nonces longer than 64-
  * bits.
  */
-static void ptls_aead_xor_iv(ptls_aead_context_t *ctx, const void *bytes, size_t len);
+void ptls_aead_xor_iv(ptls_aead_context_t *ctx, const void *bytes, size_t len);
+/**
+ * Copies the static IV of the AEAD context into `iv`, which must have room for `ctx->algo->iv_size` bytes.
+ */
+static void ptls_aead_get_iv(ptls_aead_context_t *ctx, void *iv);
+/**
+ * Replaces the static IV of the AEAD context with the `ctx->algo->iv_size` bytes pointed to by `iv`.
+ */
+static void ptls_aead_set_iv(ptls_aead_context_t *ctx, const void *iv);
 /**
  * Encrypts one AEAD block, given input and output vectors.
  */
@@ -1817,9 +1820,14 @@
     ctx->do_transform(ctx, output, input, len);
 }
 
-inline void ptls_aead_xor_iv(ptls_aead_context_t *ctx, const void *bytes, size_t len)
+inline void ptls_aead_get_iv(ptls_aead_context_t *ctx, void *iv)
 {
-    ctx->do_xor_iv(ctx, bytes, len);
+    ctx->do_get_iv(ctx, iv); /* the crypto binding writes its current static IV into `iv` */
+}
+
+inline void ptls_aead_set_iv(ptls_aead_context_t *ctx, const void *iv)
+{
+    ctx->do_set_iv(ctx, iv); /* the crypto binding overwrites its static IV with the bytes at `iv` */
 }
 
 inline size_t ptls_aead_encrypt(ptls_aead_context_t *ctx, void *output, const void *input, size_t inlen, uint64_t seq,
diff --git a/lib/cifra/aes-common.h b/lib/cifra/aes-common.h
index 234e647..bd0b58d 100644
--- a/lib/cifra/aes-common.h
+++ b/lib/cifra/aes-common.h
@@ -154,13 +154,18 @@
     return tag_offset;
 }
 
-static inline void aesgcm_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+static inline void aesgcm_get_iv(ptls_aead_context_t *_ctx, void *iv)
 {
     struct aesgcm_context_t *ctx = (struct aesgcm_context_t *)_ctx;
-    const uint8_t *bytes = _bytes;
 
-    for (size_t i = 0; i < len; ++i)
-        ctx->static_iv[i] ^= bytes[i];
+    memcpy(iv, ctx->static_iv, sizeof(ctx->static_iv)); /* NOTE(review): sizeof(static_iv) assumed == algo->iv_size */
+}
+
+static inline void aesgcm_set_iv(ptls_aead_context_t *_ctx, const void *iv)
+{
+    struct aesgcm_context_t *ctx = (struct aesgcm_context_t *)_ctx;
+
+    memcpy(ctx->static_iv, iv, sizeof(ctx->static_iv)); /* overwrites the entire static_iv array */
 }
 
 static inline int aead_aesgcm_setup_crypto(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv)
@@ -168,7 +173,8 @@
     struct aesgcm_context_t *ctx = (struct aesgcm_context_t *)_ctx;
 
     ctx->super.dispose_crypto = aesgcm_dispose_crypto;
-    ctx->super.do_xor_iv = aesgcm_xor_iv;
+    ctx->super.do_get_iv = aesgcm_get_iv;
+    ctx->super.do_set_iv = aesgcm_set_iv;
     if (is_enc) {
         ctx->super.do_encrypt_init = aesgcm_encrypt_init;
         ctx->super.do_encrypt_update = aesgcm_encrypt_update;
diff --git a/lib/cifra/chacha20.c b/lib/cifra/chacha20.c
index 3332877..95a9e56 100644
--- a/lib/cifra/chacha20.c
+++ b/lib/cifra/chacha20.c
@@ -179,13 +179,18 @@
     return ret;
 }
 
-static void chacha20poly1305_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+static void chacha20poly1305_get_iv(ptls_aead_context_t *_ctx, void *iv)
 {
     struct chacha20poly1305_context_t *ctx = (struct chacha20poly1305_context_t *)_ctx;
-    const uint8_t *bytes = _bytes;
 
-    for (size_t i = 0; i < len; ++i)
-        ctx->static_iv[i] ^= bytes[i];
+    memcpy(iv, ctx->static_iv, sizeof(ctx->static_iv)); /* NOTE(review): sizeof(static_iv) assumed == algo->iv_size */
+}
+
+static void chacha20poly1305_set_iv(ptls_aead_context_t *_ctx, const void *iv)
+{
+    struct chacha20poly1305_context_t *ctx = (struct chacha20poly1305_context_t *)_ctx;
+
+    memcpy(ctx->static_iv, iv, sizeof(ctx->static_iv)); /* overwrites the entire static_iv array */
 }
 
 static int aead_chacha20poly1305_setup_crypto(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv)
@@ -193,7 +198,8 @@
     struct chacha20poly1305_context_t *ctx = (struct chacha20poly1305_context_t *)_ctx;
 
     ctx->super.dispose_crypto = chacha20poly1305_dispose_crypto;
-    ctx->super.do_xor_iv = chacha20poly1305_xor_iv;
+    ctx->super.do_get_iv = chacha20poly1305_get_iv;
+    ctx->super.do_set_iv = chacha20poly1305_set_iv;
     if (is_enc) {
         ctx->super.do_encrypt_init = chacha20poly1305_init;
         ctx->super.do_encrypt_update = chacha20poly1305_encrypt_update;
diff --git a/lib/fusion.c b/lib/fusion.c
index b9f99da..35fe579 100644
--- a/lib/fusion.c
+++ b/lib/fusion.c
@@ -1165,12 +1165,20 @@
     return enclen;
 }
 
-static inline void aesgcm_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+static inline void aesgcm_get_iv(ptls_aead_context_t *_ctx, void *iv)
 {
     struct aesgcm_context *ctx = (struct aesgcm_context *)_ctx;
-    __m128i xor_mask = loadn128(_bytes, len);
-    xor_mask = _mm_shuffle_epi8(xor_mask, byteswap128);
-    ctx->static_iv = _mm_xor_si128(ctx->static_iv, xor_mask);
+
+    __m128i m128 = _mm_shuffle_epi8(ctx->static_iv, byteswap128); /* undo the byte swap applied by aesgcm_set_iv */
+    storen128(iv, PTLS_AESGCM_IV_SIZE, m128); /* write only PTLS_AESGCM_IV_SIZE bytes, not the whole 16-byte vector */
+}
+
+static inline void aesgcm_set_iv(ptls_aead_context_t *_ctx, const void *iv)
+{
+    struct aesgcm_context *ctx = (struct aesgcm_context *)_ctx;
+
+    ctx->static_iv = loadn128(iv, PTLS_AESGCM_IV_SIZE); /* read only PTLS_AESGCM_IV_SIZE bytes from `iv` */
+    ctx->static_iv = _mm_shuffle_epi8(ctx->static_iv, byteswap128); /* static_iv is kept byte-swapped in the register */
 }
 
 static int aesgcm_setup(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv, size_t key_size)
@@ -1183,7 +1191,8 @@
         return 0;
 
     ctx->super.dispose_crypto = aesgcm_dispose_crypto;
-    ctx->super.do_xor_iv = aesgcm_xor_iv;
+    ctx->super.do_get_iv = aesgcm_get_iv;
+    ctx->super.do_set_iv = aesgcm_set_iv;
     ctx->super.do_encrypt_init = aead_do_encrypt_init;
     ctx->super.do_encrypt_update = aead_do_encrypt_update;
     ctx->super.do_encrypt_final = aead_do_encrypt_final;
@@ -2108,7 +2117,8 @@
         return 0;
 
     ctx->super.dispose_crypto = aesgcm_dispose_crypto;
-    ctx->super.do_xor_iv = aesgcm_xor_iv;
+    ctx->super.do_get_iv = aesgcm_get_iv;
+    ctx->super.do_set_iv = aesgcm_set_iv;
     ctx->super.do_encrypt_init = NULL;
     ctx->super.do_encrypt_update = NULL;
     ctx->super.do_encrypt_final = NULL;
diff --git a/lib/openssl.c b/lib/openssl.c
index 3278f3d..b710e88 100644
--- a/lib/openssl.c
+++ b/lib/openssl.c
@@ -1006,13 +1006,18 @@
         EVP_CIPHER_CTX_free(ctx->evp_ctx);
 }
 
-static void aead_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+static void aead_get_iv(ptls_aead_context_t *_ctx, void *iv)
 {
     struct aead_crypto_context_t *ctx = (struct aead_crypto_context_t *)_ctx;
-    const uint8_t *bytes = _bytes;
 
-    for (size_t i = 0; i < len; ++i)
-        ctx->static_iv[i] ^= bytes[i];
+    memcpy(iv, ctx->static_iv, ctx->super.algo->iv_size); /* exports exactly algo->iv_size bytes */
+}
+
+static void aead_set_iv(ptls_aead_context_t *_ctx, const void *iv)
+{
+    struct aead_crypto_context_t *ctx = (struct aead_crypto_context_t *)_ctx;
+
+    memcpy(ctx->static_iv, iv, ctx->super.algo->iv_size); /* replaces exactly algo->iv_size bytes */
 }
 
 static void aead_do_encrypt_init(ptls_aead_context_t *_ctx, uint64_t seq, const void *aad, size_t aadlen)
@@ -1096,7 +1101,8 @@
     int ret;
 
     ctx->super.dispose_crypto = aead_dispose_crypto;
-    ctx->super.do_xor_iv = aead_xor_iv;
+    ctx->super.do_get_iv = aead_get_iv;
+    ctx->super.do_set_iv = aead_set_iv;
     if (is_enc) {
         ctx->super.do_encrypt_init = aead_do_encrypt_init;
         ctx->super.do_encrypt_update = aead_do_encrypt_update;
diff --git a/lib/picotls.c b/lib/picotls.c
index 4878d68..d1064e3 100644
--- a/lib/picotls.c
+++ b/lib/picotls.c
@@ -6204,6 +6204,24 @@
     free(ctx);
 }
 
+/**
+ * XORs `len` bytes of `_bytes` into the leading bytes of the static IV. The IV is read through the crypto binding's get_iv
+ * callback, modified in a stack buffer, and written back through set_iv. `len` must not exceed `ctx->algo->iv_size`.
+ */
+void ptls_aead_xor_iv(ptls_aead_context_t *ctx, const void *_bytes, size_t len)
+{
+    const uint8_t *bytes = _bytes;
+    uint8_t iv[PTLS_MAX_IV_SIZE];
+
+    /* bytes beyond the IV would overrun `iv` (and would be ignored by set_iv anyway) */
+    assert(len <= ctx->algo->iv_size);
+
+    ptls_aead_get_iv(ctx, iv);
+    for (size_t i = 0; i < len; ++i)
+        iv[i] ^= bytes[i];
+    ptls_aead_set_iv(ctx, iv);
+}
+
 void ptls_aead__build_iv(ptls_aead_algorithm_t *algo, uint8_t *iv, const uint8_t *static_iv, uint64_t seq)
 {
     size_t iv_size = algo->iv_size, i;
diff --git a/lib/ptlsbcrypt.c b/lib/ptlsbcrypt.c
index 2ffac7b..89da77d 100644
--- a/lib/ptlsbcrypt.c
+++ b/lib/ptlsbcrypt.c
@@ -476,13 +476,20 @@
     }
 }
 
-static void ptls_bcrypt_aead_xor_iv(ptls_aead_context_t *_ctx, const void *_bytes, size_t len)
+/* The static IV lives in `bctx.iv_static` (setup_crypto copies its `iv` argument there); `bctx.iv` is a different buffer. */
+static void ptls_bcrypt_aead_get_iv(ptls_aead_context_t *_ctx, void *iv)
 {
     struct ptls_bcrypt_aead_context_t *ctx = (struct ptls_bcrypt_aead_context_t *)_ctx;
-    const uint8_t *bytes = _bytes;
 
-    for (size_t i = 0; i < len; ++i)
-        ctx->bctx.iv[i] ^= bytes[i];
+    memcpy(iv, ctx->bctx.iv_static, ctx->super.algo->iv_size);
+}
+
+/* Replaces the static IV (`bctx.iv_static`) with the algo->iv_size bytes at `iv`. */
+static void ptls_bcrypt_aead_set_iv(ptls_aead_context_t *_ctx, const void *iv)
+{
+    struct ptls_bcrypt_aead_context_t *ctx = (struct ptls_bcrypt_aead_context_t *)_ctx;
+
+    memcpy(ctx->bctx.iv_static, iv, ctx->super.algo->iv_size);
 }
 
 static int ptls_bcrypt_aead_setup_crypto(ptls_aead_context_t *_ctx, int is_enc, const void *key, const void *iv,
@@ -539,7 +544,8 @@
         memcpy(ctx->bctx.iv_static, iv, ctx->super.algo->iv_size);
         if (is_enc) {
             ctx->super.dispose_crypto = ptls_bcrypt_aead_dispose_crypto;
-            ctx->super.do_xor_iv = ptls_bcrypt_aead_xor_iv;
+            ctx->super.do_get_iv = ptls_bcrypt_aead_get_iv;
+            ctx->super.do_set_iv = ptls_bcrypt_aead_set_iv;
             ctx->super.do_decrypt = NULL;
             ctx->super.do_encrypt_init = ptls_bcrypt_aead_do_encrypt_init;
             ctx->super.do_encrypt_update = ptls_bcrypt_aead_do_encrypt_update;