Allow SSL_HANDSHAKE::key_shares to vary in size

This refactors hs->key_shares to a vector rather than a fixed array of
size 2. There is at most a key share for every implemented NamedGroup,
so the size is capped at the number of named groups. This is needed for
allowing the caller to explicitly specify client key shares. There's no
behavior change intended in this CL. We still use at most 2 key shares
(by default; that will change in the next CL).

Change-Id: I0eeb218c59ec1f0e8e062aeb5f833102c40d7bf6
Bug: 437414371
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81487
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Lily Chen <chlily@google.com>
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index dcdd5f0..a09594e 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -2188,8 +2188,7 @@
 
 bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id) {
   SSL *const ssl = hs->ssl;
-  hs->key_shares[0].reset();
-  hs->key_shares[1].reset();
+  hs->key_shares.clear();
   hs->key_share_bytes.Reset();
 
   // If offering a PAKE, do not set up key shares. We do not currently support
@@ -2235,20 +2234,23 @@
   }
 
   CBB key_exchange;
-  hs->key_shares[0] = SSLKeyShare::Create(group_id);
-  if (!hs->key_shares[0] ||  //
-      !CBB_add_u16(cbb.get(), group_id) ||
-      !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
-      !hs->key_shares[0]->Generate(&key_exchange)) {
-    return false;
+  {
+    UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
+    if (key_share == nullptr || !CBB_add_u16(cbb.get(), group_id) ||
+        !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
+        !key_share->Generate(&key_exchange) ||
+        !hs->key_shares.TryPushBack(std::move(key_share))) {
+      return false;
+    }
   }
 
   if (second_group_id != 0) {
-    hs->key_shares[1] = SSLKeyShare::Create(second_group_id);
-    if (!hs->key_shares[1] ||  //
-        !CBB_add_u16(cbb.get(), second_group_id) ||
+    // TODO(chlily): Fix temporary code duplication.
+    UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(second_group_id);
+    if (key_share == nullptr || !CBB_add_u16(cbb.get(), second_group_id) ||
         !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
-        !hs->key_shares[1]->Generate(&key_exchange)) {
+        !key_share->Generate(&key_exchange) ||
+        !hs->key_shares.TryPushBack(std::move(key_share))) {
       return false;
     }
   }
@@ -2282,13 +2284,9 @@
 bool ssl_ext_key_share_parse_serverhello(SSL_HANDSHAKE *hs,
                                          Array<uint8_t> *out_secret,
                                          uint8_t *out_alert, CBS *contents) {
-  if (hs->key_shares[0] == nullptr) {
-    // If we did not offer key shares, the extension should have been rejected
-    // as unsolicited.
-    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
-    *out_alert = SSL_AD_INTERNAL_ERROR;
-    return false;
-  }
+  // If we did not offer key shares, the extension should have been rejected
+  // as unsolicited.
+  assert(!hs->key_shares.empty());
 
   CBS ciphertext;
   uint16_t group_id;
@@ -2300,24 +2298,23 @@
     return false;
   }
 
-  SSLKeyShare *key_share = hs->key_shares[0].get();
-  if (key_share->GroupID() != group_id) {
-    if (!hs->key_shares[1] || hs->key_shares[1]->GroupID() != group_id) {
-      *out_alert = SSL_AD_ILLEGAL_PARAMETER;
-      OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
-      return false;
-    }
-    key_share = hs->key_shares[1].get();
+  const auto key_share_it =
+      std::find_if(hs->key_shares.begin(), hs->key_shares.end(),
+                   [group_id](const auto &key_share) {
+                     return key_share->GroupID() == group_id;
+                   });
+  if (key_share_it == hs->key_shares.end()) {
+    *out_alert = SSL_AD_ILLEGAL_PARAMETER;
+    OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
+    return false;
   }
-
-  if (!key_share->Decap(out_secret, out_alert, ciphertext)) {
+  if (!(*key_share_it)->Decap(out_secret, out_alert, ciphertext)) {
     *out_alert = SSL_AD_INTERNAL_ERROR;
     return false;
   }
 
   hs->new_session->group_id = group_id;
-  hs->key_shares[0].reset();
-  hs->key_shares[1].reset();
+  hs->key_shares.clear();
   return true;
 }
 
diff --git a/ssl/handoff.cc b/ssl/handoff.cc
index 4c4ad35..5598b1c 100644
--- a/ssl/handoff.cc
+++ b/ssl/handoff.cc
@@ -756,9 +756,10 @@
         !CBS_get_asn1(&key_share, &private_key, CBS_ASN1_OCTETSTRING)) {
       return false;
     }
-    hs->key_shares[0] = SSLKeyShare::Create(group_id);
-    if (!hs->key_shares[0] ||
-        !hs->key_shares[0]->DeserializePrivateKey(&private_key)) {
+    UniquePtr<SSLKeyShare> ssl_key_share = SSLKeyShare::Create(group_id);
+    if (ssl_key_share == nullptr ||
+        !ssl_key_share->DeserializePrivateKey(&private_key) ||
+        !hs->key_shares.TryPushBack(std::move(ssl_key_share))) {
       return false;
     }
   }
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index 77d02fd..ff760be 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -665,8 +665,7 @@
   }
 
   // Clear some TLS 1.3 state that no longer needs to be retained.
-  hs->key_shares[0].reset();
-  hs->key_shares[1].reset();
+  hs->key_shares.clear();
   ssl_done_writing_client_hello(hs);
 
   // TLS 1.2 handshakes cannot accept ECH.
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 9f660f6..1661430 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -1025,10 +1025,12 @@
 
     if (alg_k & SSL_kECDHE) {
       assert(hs->new_session->group_id != 0);
-      hs->key_shares[0] = SSLKeyShare::Create(hs->new_session->group_id);
-      if (!hs->key_shares[0] ||                                  //
-          !CBB_add_u8(cbb.get(), NAMED_CURVE_TYPE) ||            //
-          !CBB_add_u16(cbb.get(), hs->new_session->group_id) ||  //
+      UniquePtr<SSLKeyShare> ssl_key_share =
+          SSLKeyShare::Create(hs->new_session->group_id);
+      if (ssl_key_share == nullptr ||                               //
+          !hs->key_shares.TryPushBack(std::move(ssl_key_share)) ||  //
+          !CBB_add_u8(cbb.get(), NAMED_CURVE_TYPE) ||               //
+          !CBB_add_u16(cbb.get(), hs->new_session->group_id) ||     //
           !CBB_add_u8_length_prefixed(cbb.get(), &child)) {
         return ssl_hs_error;
       }
@@ -1407,8 +1409,7 @@
     }
 
     // The key exchange state may now be discarded.
-    hs->key_shares[0].reset();
-    hs->key_shares[1].reset();
+    hs->key_shares.clear();
   } else if (!(alg_k & SSL_kPSK)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
diff --git a/ssl/internal.h b/ssl/internal.h
index 9cff133..37eb54e 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -954,6 +954,9 @@
 // NamedGroups returns all supported groups.
 Span<const NamedGroup> NamedGroups();
 
+// kNumNamedGroups is the number of supported groups.
+constexpr size_t kNumNamedGroups = 7u;
+
 // DefaultSupportedGroupIds returns the list of IDs for the default groups that
 // are supported when the caller hasn't explicitly configured supported groups.
 Span<const uint16_t> DefaultSupportedGroupIds();
@@ -1766,10 +1769,11 @@
   // error, if |wait| is |ssl_hs_error|, is the error the handshake failed on.
   UniquePtr<ERR_SAVE_STATE> error;
 
-  // key_shares are the current key exchange instances. The second is only used
-  // as a client if we believe that we should offer two key shares in a
-  // ClientHello.
-  UniquePtr<SSLKeyShare> key_shares[2];
+  // key_shares are the current key exchange instances, in preference order. Any
+  // members of this vector must be non-null. By default, no more than two are
+  // used, and the second is only used as a client if we believe that we should
+  // offer two key shares in a ClientHello.
+  InplaceVector<UniquePtr<SSLKeyShare>, kNumNamedGroups> key_shares;
 
   // transcript is the current handshake transcript.
   SSLTranscript transcript;
diff --git a/ssl/ssl_key_share.cc b/ssl/ssl_key_share.cc
index 0014d0c..9f885d5 100644
--- a/ssl/ssl_key_share.cc
+++ b/ssl/ssl_key_share.cc
@@ -445,6 +445,9 @@
     {NID_MLKEM1024, SSL_GROUP_MLKEM1024, "MLKEM1024", ""},
 };
 
+static_assert(std::size(kNamedGroups) == kNumNamedGroups,
+              "kNamedGroups size mismatch");
+
 }  // namespace
 
 Span<const NamedGroup> NamedGroups() { return kNamedGroups; }
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index e18ad11..dc417b9 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -18,6 +18,7 @@
 #include <limits.h>
 #include <string.h>
 
+#include <algorithm>
 #include <utility>
 
 #include <openssl/bytestring.h>
@@ -306,8 +307,10 @@
 
     // Check that the HelloRetryRequest does not request a key share that was
     // provided in the initial ClientHello.
-    if (hs->key_shares[0]->GroupID() == group_id ||
-        (hs->key_shares[1] && hs->key_shares[1]->GroupID() == group_id)) {
+    if (std::find_if(hs->key_shares.begin(), hs->key_shares.end(),
+                     [group_id](const auto &hs_key_share) {
+                       return hs_key_share->GroupID() == group_id;
+                     }) != hs->key_shares.end()) {
       ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
       OPENSSL_PUT_ERROR(SSL, SSL_R_WRONG_CURVE);
       return ssl_hs_error;
@@ -426,7 +429,7 @@
       ssl_session_get_type(ssl->session.get()) ==
           SSLSessionType::kPreSharedKey &&
       ssl->s3->ech_status != ssl_ech_rejected;
-  SSLExtension key_share(TLSEXT_TYPE_key_share, hs->key_shares[0] != nullptr),
+  SSLExtension key_share(TLSEXT_TYPE_key_share, !hs->key_shares.empty()),
       pake_share(TLSEXT_TYPE_pake, hs->pake_prover != nullptr),
       pre_shared_key(TLSEXT_TYPE_pre_shared_key, pre_shared_key_allowed),
       supported_versions(TLSEXT_TYPE_supported_versions);