Make AES-GCM AEADs support the optional second input argument to seal_scatter.

Change-Id: I8cf7c7ef9c3fdcc2cd1bf6669fbcd616f4c0e0ef
Reviewed-on: https://boringssl-review.googlesource.com/17364
Commit-Queue: Adam Langley <agl@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/cipher_extra/aead_test.cc b/crypto/cipher_extra/aead_test.cc
index 6536e34..cce432c 100644
--- a/crypto/cipher_extra/aead_test.cc
+++ b/crypto/cipher_extra/aead_test.cc
@@ -23,6 +23,7 @@
 #include <openssl/cipher.h>
 #include <openssl/err.h>
 
+#include "../fipsmodule/cipher/internal.h"
 #include "../internal.h"
 #include "../test/file_test.h"
 #include "../test/test_util.h"
@@ -208,6 +209,55 @@
   });
 }
 
+TEST_P(PerAEADTest, TestExtraInput) {
+  const KnownAEAD &aead_config = GetParam();
+  if (!aead()->seal_scatter_supports_extra_in) {
+    return;
+  }
+
+  const std::string test_vectors =
+      "crypto/cipher_extra/test/" + std::string(aead_config.test_vectors);
+  FileTestGTest(test_vectors.c_str(), [&](FileTest *t) {
+    if (t->HasAttribute("NO_SEAL") ||
+        t->HasAttribute("FAILS")) {
+      t->SkipCurrent();
+      return;
+    }
+
+    std::vector<uint8_t> key, nonce, in, ad, ct, tag;
+    ASSERT_TRUE(t->GetBytes(&key, "KEY"));
+    ASSERT_TRUE(t->GetBytes(&nonce, "NONCE"));
+    ASSERT_TRUE(t->GetBytes(&in, "IN"));
+    ASSERT_TRUE(t->GetBytes(&ad, "AD"));
+    ASSERT_TRUE(t->GetBytes(&ct, "CT"));
+    ASSERT_TRUE(t->GetBytes(&tag, "TAG"));
+
+    bssl::ScopedEVP_AEAD_CTX ctx;
+    ASSERT_TRUE(EVP_AEAD_CTX_init(ctx.get(), aead(), key.data(), key.size(),
+                                  tag.size(), nullptr));
+    std::vector<uint8_t> out_tag(EVP_AEAD_max_overhead(aead()) + in.size());
+    std::vector<uint8_t> out(in.size());
+
+    for (size_t extra_in_size = 0; extra_in_size < in.size(); extra_in_size++) {
+      size_t tag_bytes_written;
+      ASSERT_TRUE(EVP_AEAD_CTX_seal_scatter(
+          ctx.get(), out.data(), out_tag.data(), &tag_bytes_written,
+          out_tag.size(), nonce.data(), nonce.size(), in.data(),
+          in.size() - extra_in_size, in.data() + in.size() - extra_in_size,
+          extra_in_size, ad.data(), ad.size()));
+
+      ASSERT_EQ(tag_bytes_written, extra_in_size + tag.size());
+
+      memcpy(out.data() + in.size() - extra_in_size, out_tag.data(),
+             extra_in_size);
+
+      EXPECT_EQ(Bytes(ct), Bytes(out.data(), in.size()));
+      EXPECT_EQ(Bytes(tag), Bytes(out_tag.data() + extra_in_size,
+                                  tag_bytes_written - extra_in_size));
+    }
+  });
+}
+
 TEST_P(PerAEADTest, TestVectorScatterGather) {
   std::string test_vectors = "crypto/cipher_extra/test/";
   const KnownAEAD &aead_config = GetParam();
@@ -271,7 +321,7 @@
       int err = ERR_peek_error();
       if (ERR_GET_LIB(err) == ERR_LIB_CIPHER &&
           ERR_GET_REASON(err) == CIPHER_R_CTRL_NOT_IMPLEMENTED) {
-          (void)t->HasAttribute("FAILS");  // All attributes need to be used.
+          t->SkipCurrent();
           return;
         }
     }
diff --git a/crypto/fipsmodule/cipher/e_aes.c b/crypto/fipsmodule/cipher/e_aes.c
index 31d9877..7c7521b 100644
--- a/crypto/fipsmodule/cipher/e_aes.c
+++ b/crypto/fipsmodule/cipher/e_aes.c
@@ -1213,6 +1213,14 @@
   const struct aead_aes_gcm_ctx *gcm_ctx = ctx->aead_state;
   GCM128_CONTEXT gcm;
 
+  if (extra_in_len + ctx->tag_len < ctx->tag_len) {
+    OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_TOO_LARGE);
+    return 0;
+  }
+  if (max_out_tag_len < ctx->tag_len + extra_in_len) {
+    OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_BUFFER_TOO_SMALL);
+    return 0;
+  }
   if (nonce_len == 0) {
     OPENSSL_PUT_ERROR(CIPHER, CIPHER_R_INVALID_NONCE_SIZE);
     return 0;
@@ -1243,8 +1251,22 @@
     }
   }
 
-  CRYPTO_gcm128_tag(&gcm, out_tag, ctx->tag_len);
-  *out_tag_len = ctx->tag_len;
+  if (extra_in_len) {
+    if (gcm_ctx->ctr) {
+      if (!CRYPTO_gcm128_encrypt_ctr32(&gcm, key, extra_in, out_tag,
+                                       extra_in_len, gcm_ctx->ctr)) {
+        return 0;
+      }
+    } else {
+      if (!CRYPTO_gcm128_encrypt(&gcm, key, extra_in, out_tag, extra_in_len)) {
+        return 0;
+      }
+    }
+  }
+
+  CRYPTO_gcm128_tag(&gcm, out_tag + extra_in_len, ctx->tag_len);
+  *out_tag_len = ctx->tag_len + extra_in_len;
+
   return 1;
 }
 
@@ -1303,6 +1325,8 @@
   out->nonce_len = 12;
   out->overhead = EVP_AEAD_AES_GCM_TAG_LEN;
   out->max_tag_len = EVP_AEAD_AES_GCM_TAG_LEN;
+  out->seal_scatter_supports_extra_in = 1;
+
   out->init = aead_aes_gcm_init;
   out->cleanup = aead_aes_gcm_cleanup;
   out->seal_scatter = aead_aes_gcm_seal_scatter;
@@ -1316,6 +1340,8 @@
   out->nonce_len = 12;
   out->overhead = EVP_AEAD_AES_GCM_TAG_LEN;
   out->max_tag_len = EVP_AEAD_AES_GCM_TAG_LEN;
+  out->seal_scatter_supports_extra_in = 1;
+
   out->init = aead_aes_gcm_init;
   out->cleanup = aead_aes_gcm_cleanup;
   out->seal_scatter = aead_aes_gcm_seal_scatter;
@@ -1386,6 +1412,8 @@
   out->nonce_len = 12;
   out->overhead = EVP_AEAD_AES_GCM_TAG_LEN;
   out->max_tag_len = EVP_AEAD_AES_GCM_TAG_LEN;
+  out->seal_scatter_supports_extra_in = 1;
+
   out->init = aead_aes_gcm_tls12_init;
   out->cleanup = aead_aes_gcm_tls12_cleanup;
   out->seal_scatter = aead_aes_gcm_tls12_seal_scatter;
@@ -1399,6 +1427,8 @@
   out->nonce_len = 12;
   out->overhead = EVP_AEAD_AES_GCM_TAG_LEN;
   out->max_tag_len = EVP_AEAD_AES_GCM_TAG_LEN;
+  out->seal_scatter_supports_extra_in = 1;
+
   out->init = aead_aes_gcm_tls12_init;
   out->cleanup = aead_aes_gcm_tls12_cleanup;
   out->seal_scatter = aead_aes_gcm_tls12_seal_scatter;
diff --git a/crypto/test/file_test.cc b/crypto/test/file_test.cc
index 5bc5fb4..c85fac6 100644
--- a/crypto/test/file_test.cc
+++ b/crypto/test/file_test.cc
@@ -465,3 +465,7 @@
 
   return failed ? 1 : 0;
 }
+
+void FileTest::SkipCurrent() {
+  ClearTest();
+}
diff --git a/crypto/test/file_test.h b/crypto/test/file_test.h
index 8c476c4..fd1dcd7 100644
--- a/crypto/test/file_test.h
+++ b/crypto/test/file_test.h
@@ -184,6 +184,9 @@
   // instructions.
   void InjectInstruction(const std::string &key, const std::string &value);
 
+  // SkipCurrent passes the current test case. Unused attributes are ignored.
+  void SkipCurrent();
+
  private:
   void ClearTest();
   void ClearInstructions();