Flatten EVP_AEAD_CTX

An EVP_AEAD_CTX used to be a small struct that contained a pointer to
an AEAD-specific context. That involved heap allocating the
AEAD-specific context, which was a problem for users who wanted to set up
and discard these objects quickly.

Instead this change makes EVP_AEAD_CTX large enough to contain the
AEAD-specific context inside itself. The dominant AEAD is AES-GCM, and
that's also the largest. So, in practice, this shouldn't waste too much
memory.

Change-Id: I795cb37afae9df1424f882adaf514a222e040c80
Reviewed-on: https://boringssl-review.googlesource.com/c/32506
Commit-Queue: Adam Langley <agl@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/cipher_extra/e_aesccm.c b/crypto/cipher_extra/e_aesccm.c
index 87f16dc..37a9add 100644
--- a/crypto/cipher_extra/e_aesccm.c
+++ b/crypto/cipher_extra/e_aesccm.c
@@ -33,6 +33,15 @@
   CCM128_CONTEXT ccm;
 };
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(struct aead_aes_ccm_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                           alignof(struct aead_aes_ccm_ctx),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 static int aead_aes_ccm_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                              size_t key_len, size_t tag_len, unsigned M,
                              unsigned L) {
@@ -54,36 +63,28 @@
     return 0;
   }
 
-  struct aead_aes_ccm_ctx *ccm_ctx =
-      OPENSSL_malloc(sizeof(struct aead_aes_ccm_ctx));
-  if (ccm_ctx == NULL) {
-    OPENSSL_PUT_ERROR(CIPHER, ERR_R_MALLOC_FAILURE);
-    return 0;
-  }
+  struct aead_aes_ccm_ctx *ccm_ctx = (struct aead_aes_ccm_ctx *)&ctx->state;
 
   block128_f block;
   ctr128_f ctr = aes_ctr_set_key(&ccm_ctx->ks.ks, NULL, &block, key, key_len);
   ctx->tag_len = tag_len;
   if (!CRYPTO_ccm128_init(&ccm_ctx->ccm, &ccm_ctx->ks.ks, block, ctr, M, L)) {
     OPENSSL_PUT_ERROR(CIPHER, ERR_R_INTERNAL_ERROR);
-    OPENSSL_free(ccm_ctx);
     return 0;
   }
 
-  ctx->aead_state = ccm_ctx;
   return 1;
 }
 
-static void aead_aes_ccm_cleanup(EVP_AEAD_CTX *ctx) {
-  OPENSSL_free(ctx->aead_state);
-}
+static void aead_aes_ccm_cleanup(EVP_AEAD_CTX *ctx) {}
 
 static int aead_aes_ccm_seal_scatter(
     const EVP_AEAD_CTX *ctx, uint8_t *out, uint8_t *out_tag,
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_ccm_ctx *ccm_ctx = ctx->aead_state;
+  const struct aead_aes_ccm_ctx *ccm_ctx =
+      (struct aead_aes_ccm_ctx *)&ctx->state;
 
   if (in_len > CRYPTO_ccm128_max_input(&ccm_ctx->ccm)) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
@@ -116,7 +117,8 @@
                                     const uint8_t *in, size_t in_len,
                                     const uint8_t *in_tag, size_t in_tag_len,
                                     const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_ccm_ctx *ccm_ctx = ctx->aead_state;
+  const struct aead_aes_ccm_ctx *ccm_ctx =
+      (struct aead_aes_ccm_ctx *)&ctx->state;
 
   if (in_len > CRYPTO_ccm128_max_input(&ccm_ctx->ccm)) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
diff --git a/crypto/cipher_extra/e_aesctrhmac.c b/crypto/cipher_extra/e_aesctrhmac.c
index 3a0de9b..54a50ec 100644
--- a/crypto/cipher_extra/e_aesctrhmac.c
+++ b/crypto/cipher_extra/e_aesctrhmac.c
@@ -35,6 +35,15 @@
   SHA256_CTX outer_init_state;
 };
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(struct aead_aes_ctr_hmac_sha256_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                           alignof(struct aead_aes_ctr_hmac_sha256_ctx),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 static void hmac_init(SHA256_CTX *out_inner, SHA256_CTX *out_outer,
                       const uint8_t hmac_key[32]) {
   static const size_t hmac_key_len = 32;
@@ -61,7 +70,8 @@
 
 static int aead_aes_ctr_hmac_sha256_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                                          size_t key_len, size_t tag_len) {
-  struct aead_aes_ctr_hmac_sha256_ctx *aes_ctx;
+  struct aead_aes_ctr_hmac_sha256_ctx *aes_ctx =
+      (struct aead_aes_ctr_hmac_sha256_ctx *)&ctx->state;
   static const size_t hmac_key_len = 32;
 
   if (key_len < hmac_key_len) {
@@ -84,26 +94,16 @@
     return 0;
   }
 
-  aes_ctx = OPENSSL_malloc(sizeof(struct aead_aes_ctr_hmac_sha256_ctx));
-  if (aes_ctx == NULL) {
-    OPENSSL_PUT_ERROR(CIPHER, ERR_R_MALLOC_FAILURE);
-    return 0;
-  }
-
   aes_ctx->ctr =
       aes_ctr_set_key(&aes_ctx->ks.ks, NULL, &aes_ctx->block, key, aes_key_len);
   ctx->tag_len = tag_len;
   hmac_init(&aes_ctx->inner_init_state, &aes_ctx->outer_init_state,
             key + aes_key_len);
 
-  ctx->aead_state = aes_ctx;
-
   return 1;
 }
 
-static void aead_aes_ctr_hmac_sha256_cleanup(EVP_AEAD_CTX *ctx) {
-  OPENSSL_free(ctx->aead_state);
-}
+static void aead_aes_ctr_hmac_sha256_cleanup(EVP_AEAD_CTX *ctx) {}
 
 static void hmac_update_uint64(SHA256_CTX *sha256, uint64_t value) {
   unsigned i;
@@ -178,7 +178,8 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_ctr_hmac_sha256_ctx *aes_ctx = ctx->aead_state;
+  const struct aead_aes_ctr_hmac_sha256_ctx *aes_ctx =
+      (struct aead_aes_ctr_hmac_sha256_ctx *) &ctx->state;
   const uint64_t in_len_64 = in_len;
 
   if (in_len_64 >= (UINT64_C(1) << 32) * AES_BLOCK_SIZE) {
@@ -212,7 +213,8 @@
     const EVP_AEAD_CTX *ctx, uint8_t *out, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *in_tag,
     size_t in_tag_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_ctr_hmac_sha256_ctx *aes_ctx = ctx->aead_state;
+  const struct aead_aes_ctr_hmac_sha256_ctx *aes_ctx =
+      (struct aead_aes_ctr_hmac_sha256_ctx *) &ctx->state;
 
   if (in_tag_len != ctx->tag_len) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_BAD_DECRYPT);
diff --git a/crypto/cipher_extra/e_aesgcmsiv.c b/crypto/cipher_extra/e_aesgcmsiv.c
index 9de2300..fe43633 100644
--- a/crypto/cipher_extra/e_aesgcmsiv.c
+++ b/crypto/cipher_extra/e_aesgcmsiv.c
@@ -34,12 +34,27 @@
 struct aead_aes_gcm_siv_asm_ctx {
   alignas(16) uint8_t key[16*15];
   int is_128_bit;
-  // ptr contains the original pointer from |OPENSSL_malloc|, which may only be
-  // 8-byte aligned. When freeing this structure, actually call |OPENSSL_free|
-  // on this pointer.
-  void *ptr;
 };
 
+// The state is only guaranteed to be 8-byte aligned, so up to eight bytes of
+// it may be skipped to reach the 16-byte alignment that the context requires.
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(struct aead_aes_gcm_siv_asm_ctx) + 8,
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >= 8,
+                       AEAD_state_insufficient_alignment);
+#endif
+
+// asm_ctx_from_ctx returns a 16-byte aligned context pointer from |ctx|.
+static struct aead_aes_gcm_siv_asm_ctx *asm_ctx_from_ctx(
+    const EVP_AEAD_CTX *ctx) {
+  // ctx->state must already be 8-byte aligned. Thus, at most, we may need to
+  // add eight to align it to 16 bytes.
+  const uintptr_t offset = ((uintptr_t)&ctx->state) & 8;
+  return (struct aead_aes_gcm_siv_asm_ctx *)(&ctx->state.opaque[offset]);
+}
+
 // aes128gcmsiv_aes_ks writes an AES-128 key schedule for |key| to
 // |out_expanded_key|.
 extern void aes128gcmsiv_aes_ks(
@@ -68,18 +83,8 @@
     return 0;
   }
 
-  char *ptr = OPENSSL_malloc(sizeof(struct aead_aes_gcm_siv_asm_ctx) + 8);
-  if (ptr == NULL) {
-    return 0;
-  }
-  assert((((uintptr_t)ptr) & 7) == 0);
-
-  // gcm_siv_ctx needs to be 16-byte aligned in a cross-platform way.
-  struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx =
-      (struct aead_aes_gcm_siv_asm_ctx *)(ptr + (((uintptr_t)ptr) & 8));
-
+  struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx = asm_ctx_from_ctx(ctx);
   assert((((uintptr_t)gcm_siv_ctx) & 15) == 0);
-  gcm_siv_ctx->ptr = ptr;
 
   if (key_bits == 128) {
     aes128gcmsiv_aes_ks(key, &gcm_siv_ctx->key[0]);
@@ -88,16 +93,13 @@
     aes256gcmsiv_aes_ks(key, &gcm_siv_ctx->key[0]);
     gcm_siv_ctx->is_128_bit = 0;
   }
-  ctx->aead_state = gcm_siv_ctx;
+
   ctx->tag_len = tag_len;
 
   return 1;
 }
 
-static void aead_aes_gcm_siv_asm_cleanup(EVP_AEAD_CTX *ctx) {
-  const struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx = ctx->aead_state;
-  OPENSSL_free(gcm_siv_ctx->ptr);
-}
+static void aead_aes_gcm_siv_asm_cleanup(EVP_AEAD_CTX *ctx) {}
 
 // aesgcmsiv_polyval_horner updates the POLYVAL value in |in_out_poly| to
 // include a number (|in_blocks|) of 16-byte blocks of data from |in|, given
@@ -337,7 +339,7 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx = ctx->aead_state;
+  const struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx = asm_ctx_from_ctx(ctx);
   const uint64_t in_len_64 = in_len;
   const uint64_t ad_len_64 = ad_len;
 
@@ -420,7 +422,7 @@
     return 0;
   }
 
-  const struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx = ctx->aead_state;
+  const struct aead_aes_gcm_siv_asm_ctx *gcm_siv_ctx = asm_ctx_from_ctx(ctx);
   const size_t plaintext_len = in_len - EVP_AEAD_AES_GCM_SIV_TAG_LEN;
   const uint8_t *const given_tag = in + plaintext_len;
 
@@ -558,6 +560,15 @@
   unsigned is_256:1;
 };
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                       sizeof(struct aead_aes_gcm_siv_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                       alignof(struct aead_aes_gcm_siv_ctx),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 static int aead_aes_gcm_siv_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                                  size_t key_len, size_t tag_len) {
   const size_t key_bits = key_len * 8;
@@ -576,24 +587,18 @@
   }
 
   struct aead_aes_gcm_siv_ctx *gcm_siv_ctx =
-      OPENSSL_malloc(sizeof(struct aead_aes_gcm_siv_ctx));
-  if (gcm_siv_ctx == NULL) {
-    return 0;
-  }
+      (struct aead_aes_gcm_siv_ctx *)&ctx->state;
   OPENSSL_memset(gcm_siv_ctx, 0, sizeof(struct aead_aes_gcm_siv_ctx));
 
   aes_ctr_set_key(&gcm_siv_ctx->ks.ks, NULL, &gcm_siv_ctx->kgk_block, key,
                   key_len);
   gcm_siv_ctx->is_256 = (key_len == 32);
-  ctx->aead_state = gcm_siv_ctx;
   ctx->tag_len = tag_len;
 
   return 1;
 }
 
-static void aead_aes_gcm_siv_cleanup(EVP_AEAD_CTX *ctx) {
-  OPENSSL_free(ctx->aead_state);
-}
+static void aead_aes_gcm_siv_cleanup(EVP_AEAD_CTX *ctx) {}
 
 // gcm_siv_crypt encrypts (or decrypts—it's the same thing) |in_len| bytes from
 // |in| to |out|, using the block function |enc_block| with |key| in counter
@@ -718,7 +723,8 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_gcm_siv_ctx *gcm_siv_ctx = ctx->aead_state;
+  const struct aead_aes_gcm_siv_ctx *gcm_siv_ctx =
+      (struct aead_aes_gcm_siv_ctx *)&ctx->state;
   const uint64_t in_len_64 = in_len;
   const uint64_t ad_len_64 = ad_len;
 
@@ -778,7 +784,8 @@
     return 0;
   }
 
-  const struct aead_aes_gcm_siv_ctx *gcm_siv_ctx = ctx->aead_state;
+  const struct aead_aes_gcm_siv_ctx *gcm_siv_ctx =
+      (struct aead_aes_gcm_siv_ctx *)&ctx->state;
 
   struct gcm_siv_record_keys keys;
   gcm_siv_keys(gcm_siv_ctx, &keys, nonce);
diff --git a/crypto/cipher_extra/e_chacha20poly1305.c b/crypto/cipher_extra/e_chacha20poly1305.c
index e632010..5aee4ae 100644
--- a/crypto/cipher_extra/e_chacha20poly1305.c
+++ b/crypto/cipher_extra/e_chacha20poly1305.c
@@ -35,6 +35,15 @@
   uint8_t key[32];
 };
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                       sizeof(struct aead_chacha20_poly1305_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                           alignof(struct aead_chacha20_poly1305_ctx),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 // For convenience (the x86_64 calling convention allows only six parameters in
 // registers), the final parameter for the assembly functions is both an input
 // and output parameter.
@@ -109,7 +118,8 @@
 
 static int aead_chacha20_poly1305_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                                        size_t key_len, size_t tag_len) {
-  struct aead_chacha20_poly1305_ctx *c20_ctx;
+  struct aead_chacha20_poly1305_ctx *c20_ctx =
+      (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
   if (tag_len == 0) {
     tag_len = POLY1305_TAG_LEN;
@@ -124,21 +134,13 @@
     return 0;  // internal error - EVP_AEAD_CTX_init should catch this.
   }
 
-  c20_ctx = OPENSSL_malloc(sizeof(struct aead_chacha20_poly1305_ctx));
-  if (c20_ctx == NULL) {
-    return 0;
-  }
-
   OPENSSL_memcpy(c20_ctx->key, key, key_len);
-  ctx->aead_state = c20_ctx;
   ctx->tag_len = tag_len;
 
   return 1;
 }
 
-static void aead_chacha20_poly1305_cleanup(EVP_AEAD_CTX *ctx) {
-  OPENSSL_free(ctx->aead_state);
-}
+static void aead_chacha20_poly1305_cleanup(EVP_AEAD_CTX *ctx) {}
 
 static void poly1305_update_length(poly1305_state *poly1305, size_t data_len) {
   uint8_t length_bytes[8];
@@ -260,7 +262,8 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_chacha20_poly1305_ctx *c20_ctx = ctx->aead_state;
+  const struct aead_chacha20_poly1305_ctx *c20_ctx =
+      (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
   return chacha20_poly1305_seal_scatter(
       c20_ctx->key, out, out_tag, out_tag_len, max_out_tag_len, nonce,
@@ -272,7 +275,8 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_chacha20_poly1305_ctx *c20_ctx = ctx->aead_state;
+  const struct aead_chacha20_poly1305_ctx *c20_ctx =
+      (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
   if (nonce_len != 24) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_UNSUPPORTED_NONCE_SIZE);
@@ -340,7 +344,8 @@
     const EVP_AEAD_CTX *ctx, uint8_t *out, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *in_tag,
     size_t in_tag_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_chacha20_poly1305_ctx *c20_ctx = ctx->aead_state;
+  const struct aead_chacha20_poly1305_ctx *c20_ctx =
+      (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
   return chacha20_poly1305_open_gather(c20_ctx->key, out, nonce, nonce_len, in,
                                        in_len, in_tag, in_tag_len, ad, ad_len,
@@ -351,7 +356,8 @@
     const EVP_AEAD_CTX *ctx, uint8_t *out, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *in_tag,
     size_t in_tag_len, const uint8_t *ad, size_t ad_len) {
-  const struct aead_chacha20_poly1305_ctx *c20_ctx = ctx->aead_state;
+  const struct aead_chacha20_poly1305_ctx *c20_ctx =
+      (struct aead_chacha20_poly1305_ctx *)&ctx->state;
 
   if (nonce_len != 24) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_UNSUPPORTED_NONCE_SIZE);
diff --git a/crypto/cipher_extra/e_tls.c b/crypto/cipher_extra/e_tls.c
index bba22be..1f1fc3a 100644
--- a/crypto/cipher_extra/e_tls.c
+++ b/crypto/cipher_extra/e_tls.c
@@ -44,12 +44,19 @@
 
 OPENSSL_COMPILE_ASSERT(EVP_MAX_MD_SIZE < 256, mac_key_len_fits_in_uint8_t);
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(AEAD_TLS_CTX),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                           alignof(AEAD_TLS_CTX),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 static void aead_tls_cleanup(EVP_AEAD_CTX *ctx) {
-  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)ctx->aead_state;
+  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)&ctx->state;
   EVP_CIPHER_CTX_cleanup(&tls_ctx->cipher_ctx);
   HMAC_CTX_cleanup(&tls_ctx->hmac_ctx);
-  OPENSSL_free(tls_ctx);
-  ctx->aead_state = NULL;
 }
 
 static int aead_tls_init(EVP_AEAD_CTX *ctx, const uint8_t *key, size_t key_len,
@@ -72,11 +79,7 @@
   assert(mac_key_len + enc_key_len +
          (implicit_iv ? EVP_CIPHER_iv_length(cipher) : 0) == key_len);
 
-  AEAD_TLS_CTX *tls_ctx = OPENSSL_malloc(sizeof(AEAD_TLS_CTX));
-  if (tls_ctx == NULL) {
-    OPENSSL_PUT_ERROR(CIPHER, ERR_R_MALLOC_FAILURE);
-    return 0;
-  }
+  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)&ctx->state;
   EVP_CIPHER_CTX_init(&tls_ctx->cipher_ctx);
   HMAC_CTX_init(&tls_ctx->hmac_ctx);
   assert(mac_key_len <= EVP_MAX_MD_SIZE);
@@ -84,13 +87,11 @@
   tls_ctx->mac_key_len = (uint8_t)mac_key_len;
   tls_ctx->implicit_iv = implicit_iv;
 
-  ctx->aead_state = tls_ctx;
   if (!EVP_CipherInit_ex(&tls_ctx->cipher_ctx, cipher, NULL, &key[mac_key_len],
                          implicit_iv ? &key[mac_key_len + enc_key_len] : NULL,
                          dir == evp_aead_seal) ||
       !HMAC_Init_ex(&tls_ctx->hmac_ctx, key, mac_key_len, md, NULL)) {
     aead_tls_cleanup(ctx);
-    ctx->aead_state = NULL;
     return 0;
   }
   EVP_CIPHER_CTX_set_padding(&tls_ctx->cipher_ctx, 0);
@@ -101,7 +102,7 @@
 static size_t aead_tls_tag_len(const EVP_AEAD_CTX *ctx, const size_t in_len,
                                const size_t extra_in_len) {
   assert(extra_in_len == 0);
-  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)ctx->aead_state;
+  const AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)&ctx->state;
 
   const size_t hmac_len = HMAC_size(&tls_ctx->hmac_ctx);
   if (EVP_CIPHER_CTX_mode(&tls_ctx->cipher_ctx) != EVP_CIPH_CBC_MODE) {
@@ -125,7 +126,7 @@
                                  const uint8_t *extra_in,
                                  const size_t extra_in_len, const uint8_t *ad,
                                  const size_t ad_len) {
-  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)ctx->aead_state;
+  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)&ctx->state;
 
   if (!tls_ctx->cipher_ctx.encrypt) {
     // Unlike a normal AEAD, a TLS AEAD may only be used in one direction.
@@ -241,7 +242,7 @@
                          size_t max_out_len, const uint8_t *nonce,
                          size_t nonce_len, const uint8_t *in, size_t in_len,
                          const uint8_t *ad, size_t ad_len) {
-  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)ctx->aead_state;
+  AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)&ctx->state;
 
   if (tls_ctx->cipher_ctx.encrypt) {
     // Unlike a normal AEAD, a TLS AEAD may only be used in one direction.
@@ -453,7 +454,7 @@
 
 static int aead_tls_get_iv(const EVP_AEAD_CTX *ctx, const uint8_t **out_iv,
                            size_t *out_iv_len) {
-  const AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX*) ctx->aead_state;
+  const AEAD_TLS_CTX *tls_ctx = (AEAD_TLS_CTX *)&ctx->state;
   const size_t iv_len = EVP_CIPHER_CTX_iv_length(&tls_ctx->cipher_ctx);
   if (iv_len <= 1) {
     return 0;
diff --git a/crypto/fipsmodule/cipher/e_aes.c b/crypto/fipsmodule/cipher/e_aes.c
index 0ced193..32b51a1 100644
--- a/crypto/fipsmodule/cipher/e_aes.c
+++ b/crypto/fipsmodule/cipher/e_aes.c
@@ -906,29 +906,30 @@
   return 1;
 }
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(struct aead_aes_gcm_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+  OPENSSL_COMPILE_ASSERT(
+      alignof(union evp_aead_ctx_st_state) >= alignof(struct aead_aes_gcm_ctx),
+      AEAD_state_insufficient_alignment);
+#endif
+
 static int aead_aes_gcm_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                              size_t key_len, size_t requested_tag_len) {
-  struct aead_aes_gcm_ctx *gcm_ctx;
-  gcm_ctx = OPENSSL_malloc(sizeof(struct aead_aes_gcm_ctx));
-  if (gcm_ctx == NULL) {
-    return 0;
-  }
+  struct aead_aes_gcm_ctx *gcm_ctx = (struct aead_aes_gcm_ctx *) &ctx->state;
 
   size_t actual_tag_len;
   if (!aead_aes_gcm_init_impl(gcm_ctx, &actual_tag_len, key, key_len,
                               requested_tag_len)) {
-    OPENSSL_free(gcm_ctx);
     return 0;
   }
 
-  ctx->aead_state = gcm_ctx;
   ctx->tag_len = actual_tag_len;
   return 1;
 }
 
-static void aead_aes_gcm_cleanup(EVP_AEAD_CTX *ctx) {
-  OPENSSL_free(ctx->aead_state);
-}
+static void aead_aes_gcm_cleanup(EVP_AEAD_CTX *ctx) {}
 
 static int aead_aes_gcm_seal_scatter(const EVP_AEAD_CTX *ctx, uint8_t *out,
                                      uint8_t *out_tag, size_t *out_tag_len,
@@ -938,7 +939,7 @@
                                      const uint8_t *extra_in,
                                      size_t extra_in_len,
                                      const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_gcm_ctx *gcm_ctx = ctx->aead_state;
+  struct aead_aes_gcm_ctx *gcm_ctx = (struct aead_aes_gcm_ctx *) &ctx->state;
 
   if (extra_in_len + ctx->tag_len < ctx->tag_len) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
@@ -999,7 +1000,7 @@
                                     const uint8_t *in, size_t in_len,
                                     const uint8_t *in_tag, size_t in_tag_len,
                                     const uint8_t *ad, size_t ad_len) {
-  const struct aead_aes_gcm_ctx *gcm_ctx = ctx->aead_state;
+  struct aead_aes_gcm_ctx *gcm_ctx = (struct aead_aes_gcm_ctx *) &ctx->state;
   uint8_t tag[EVP_AEAD_AES_GCM_TAG_LEN];
 
   if (nonce_len == 0) {
@@ -1078,24 +1079,28 @@
   uint64_t min_next_nonce;
 };
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(struct aead_aes_gcm_tls12_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                           alignof(struct aead_aes_gcm_tls12_ctx),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 static int aead_aes_gcm_tls12_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                                    size_t key_len, size_t requested_tag_len) {
-  struct aead_aes_gcm_tls12_ctx *gcm_ctx;
-  gcm_ctx = OPENSSL_malloc(sizeof(struct aead_aes_gcm_tls12_ctx));
-  if (gcm_ctx == NULL) {
-    return 0;
-  }
+  struct aead_aes_gcm_tls12_ctx *gcm_ctx =
+      (struct aead_aes_gcm_tls12_ctx *) &ctx->state;
 
   gcm_ctx->min_next_nonce = 0;
 
   size_t actual_tag_len;
   if (!aead_aes_gcm_init_impl(&gcm_ctx->gcm_ctx, &actual_tag_len, key, key_len,
                               requested_tag_len)) {
-    OPENSSL_free(gcm_ctx);
     return 0;
   }
 
-  ctx->aead_state = gcm_ctx;
   ctx->tag_len = actual_tag_len;
   return 1;
 }
@@ -1105,7 +1110,9 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  struct aead_aes_gcm_tls12_ctx *gcm_ctx = ctx->aead_state;
+  struct aead_aes_gcm_tls12_ctx *gcm_ctx =
+      (struct aead_aes_gcm_tls12_ctx *) &ctx->state;
+
   if (nonce_len != 12) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_UNSUPPORTED_NONCE_SIZE);
     return 0;
@@ -1166,13 +1173,19 @@
   uint8_t first;
 };
 
+OPENSSL_COMPILE_ASSERT(sizeof(((EVP_AEAD_CTX *)NULL)->state) >=
+                           sizeof(struct aead_aes_gcm_tls13_ctx),
+                       AEAD_state_too_small);
+#if defined(__GNUC__) || defined(__clang__)
+OPENSSL_COMPILE_ASSERT(alignof(union evp_aead_ctx_st_state) >=
+                           alignof(struct aead_aes_gcm_tls13_ctx),
+                       AEAD_state_insufficient_alignment);
+#endif
+
 static int aead_aes_gcm_tls13_init(EVP_AEAD_CTX *ctx, const uint8_t *key,
                                    size_t key_len, size_t requested_tag_len) {
-  struct aead_aes_gcm_tls13_ctx *gcm_ctx;
-  gcm_ctx = OPENSSL_malloc(sizeof(struct aead_aes_gcm_tls13_ctx));
-  if (gcm_ctx == NULL) {
-    return 0;
-  }
+  struct aead_aes_gcm_tls13_ctx *gcm_ctx =
+      (struct aead_aes_gcm_tls13_ctx *) &ctx->state;
 
   gcm_ctx->min_next_nonce = 0;
   gcm_ctx->first = 1;
@@ -1180,11 +1193,9 @@
   size_t actual_tag_len;
   if (!aead_aes_gcm_init_impl(&gcm_ctx->gcm_ctx, &actual_tag_len, key, key_len,
                               requested_tag_len)) {
-    OPENSSL_free(gcm_ctx);
     return 0;
   }
 
-  ctx->aead_state = gcm_ctx;
   ctx->tag_len = actual_tag_len;
   return 1;
 }
@@ -1194,7 +1205,9 @@
     size_t *out_tag_len, size_t max_out_tag_len, const uint8_t *nonce,
     size_t nonce_len, const uint8_t *in, size_t in_len, const uint8_t *extra_in,
     size_t extra_in_len, const uint8_t *ad, size_t ad_len) {
-  struct aead_aes_gcm_tls13_ctx *gcm_ctx = ctx->aead_state;
+  struct aead_aes_gcm_tls13_ctx *gcm_ctx =
+      (struct aead_aes_gcm_tls13_ctx *) &ctx->state;
+
   if (nonce_len != 12) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_UNSUPPORTED_NONCE_SIZE);
     return 0;
diff --git a/include/openssl/aead.h b/include/openssl/aead.h
index f19344e..963b2c7 100644
--- a/include/openssl/aead.h
+++ b/include/openssl/aead.h
@@ -170,13 +170,16 @@
 
 // AEAD operations.
 
+union evp_aead_ctx_st_state {
+  uint8_t opaque[580];
+  uint64_t alignment;
+};
+
 // An EVP_AEAD_CTX represents an AEAD algorithm configured with a specific key
 // and message-independent IV.
 typedef struct evp_aead_ctx_st {
   const EVP_AEAD *aead;
-  // aead_state is an opaque pointer to whatever state the AEAD needs to
-  // maintain.
-  void *aead_state;
+  union evp_aead_ctx_st_state state;
   // tag_len may contain the actual length of the authentication tag if it is
   // known at initialization time.
   uint8_t tag_len;