Add API for configuring client key shares

This change introduces a new API to allow callers to configure the exact
set of key shares to be sent in a client's key_share extension. If the
supported groups for a connection are modified, any previously-selected
key shares are cleared if they are no longer compatible. Clients are
allowed to configure an empty list of key shares, which results in
always taking a round-trip for HelloRetryRequest.

Bug: 437414371
Change-Id: Ibd2a9b217fa3c746dec194a027d3ec45a81bb578
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81567
Commit-Queue: Lily Chen <chlily@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/mem_internal.h b/crypto/mem_internal.h
index 20b45fd..5ba44dc 100644
--- a/crypto/mem_internal.h
+++ b/crypto/mem_internal.h
@@ -379,6 +379,8 @@
 template <typename T, size_t N>
 class InplaceVector {
  public:
+  using value_type = std::remove_cv_t<T>;
+
   InplaceVector() = default;
   InplaceVector(const InplaceVector &other) { *this = other; }
   InplaceVector(InplaceVector &&other) { *this = std::move(other); }
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 79eb0c7..81de2fd 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -2605,6 +2605,47 @@
 OPENSSL_EXPORT int SSL_get_negotiated_group(const SSL *ssl);
 
 
+// Client key shares.
+//
+// The key_share extension in TLS 1.3 (RFC 8446 section 4.2.8) may be sent in
+// the initial ClientHello to provide key exchange parameters for a subset of
+// the groups offered in the client's supported_groups extension, in hopes of
+// saving a round-trip by having proactively started a key exchange for the
+// ultimately-negotiated group.
+//
+// If not otherwise configured, the default client key share selection logic
+// outputs key shares for up to two supported groups, at most one of which is
+// post-quantum.
+
+// SSL_set1_client_key_shares, when called by a client before the handshake,
+// configures |ssl| to send a key_share extension in the initial ClientHello
+// containing exactly the groups given by |group_ids|, in the order given. Each
+// member of |group_ids| should be one of the |SSL_GROUP_*| constants, and they
+// must be unique. This function returns one on success and zero on failure.
+//
+// If non-empty, the sequence of |group_ids| must be a (not necessarily
+// contiguous) subsequence of the groups supported by |ssl|, which may have been
+// configured explicitly on |ssl| or its context, or populated by default.
+// Caller should finish configuring the group list before calling this function.
+// Changing the supported groups for |ssl| after having set client key shares
+// will result in the key share selections being reset if this constraint no
+// longer holds.
+//
+// Setting an empty sequence of |group_ids| results in an empty client
+// key_share, which will cause the handshake to always take an extra round-trip
+// for HelloRetryRequest.
+//
+// An extra round-trip will be needed if the server's choice of group is not
+// among the key shares sent; conversely, sending any key shares other than the
+// server's choice wastes CPU and bandwidth (the latter is particularly costly
+// for post-quantum key exchanges). To avoid these sub-optimal outcomes,
+// key shares should be chosen such that they are likely to be supported by the
+// peer server.
+OPENSSL_EXPORT int SSL_set1_client_key_shares(SSL *ssl,
+                                              const uint16_t *group_ids,
+                                              size_t num_group_ids);
+
+
 // Certificate verification.
 //
 // SSL may authenticate either endpoint with an X.509 certificate. Typically
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index a09594e..2169966 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -2197,6 +2197,45 @@
     return true;
   }
 
+  const Span<const uint16_t> supported_group_list =
+      hs->config->supported_group_list;
+  if (supported_group_list.empty()) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED);
+    return false;
+  }
+
+  InplaceVector<uint16_t, 2> default_key_shares;
+
+  Span<const uint16_t> selected_key_shares;
+
+  // Determine the key shares to send.
+  if (override_group_id != 0) {
+    assert(std::find(supported_group_list.begin(), supported_group_list.end(),
+                     override_group_id) != supported_group_list.end());
+    selected_key_shares = Span(&override_group_id, 1u);
+  } else if (ssl->config->client_key_share_selections.has_value()) {
+    selected_key_shares = *(ssl->config->client_key_share_selections);
+  } else {
+    // By default, predict the most preferred group.
+    if (!default_key_shares.TryPushBack(supported_group_list[0])) {
+      return false;
+    }
+    // We'll try to include one post-quantum and one classical initial key
+    // share.
+    for (size_t i = 1; i < supported_group_list.size(); i++) {
+      if (is_post_quantum_group(default_key_shares[0]) ==
+          is_post_quantum_group(supported_group_list[i])) {
+        continue;
+      }
+      if (!default_key_shares.TryPushBack(supported_group_list[i])) {
+        return false;
+      }
+      assert(default_key_shares[1] != default_key_shares[0]);
+      break;
+    }
+    selected_key_shares = default_key_shares;
+  }
+
   bssl::ScopedCBB cbb;
   if (!CBB_init(cbb.get(), 64)) {
     return false;
@@ -2211,45 +2250,13 @@
     }
   }
 
-  uint16_t group_id = override_group_id;
-  uint16_t second_group_id = 0;
-  if (override_group_id == 0) {
-    // Predict the most preferred group.
-    Span<const uint16_t> groups = hs->config->supported_group_list;
-    if (groups.empty()) {
-      OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED);
-      return false;
-    }
-
-    group_id = groups[0];
-
-    // We'll try to include one post-quantum and one classical initial key
-    // share.
-    for (size_t i = 1; i < groups.size() && second_group_id == 0; i++) {
-      if (is_post_quantum_group(group_id) != is_post_quantum_group(groups[i])) {
-        second_group_id = groups[i];
-        assert(second_group_id != group_id);
-      }
-    }
-  }
-
   CBB key_exchange;
-  {
+  for (const uint16_t group_id : selected_key_shares) {
     UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
-    if (key_share == nullptr || !CBB_add_u16(cbb.get(), group_id) ||
-        !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
-        !key_share->Generate(&key_exchange) ||
-        !hs->key_shares.TryPushBack(std::move(key_share))) {
-      return false;
-    }
-  }
-
-  if (second_group_id != 0) {
-    // TODO(chlily): Fix temporary code duplication.
-    UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(second_group_id);
-    if (key_share == nullptr || !CBB_add_u16(cbb.get(), second_group_id) ||
-        !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
-        !key_share->Generate(&key_exchange) ||
+    if (key_share == nullptr ||                                    //
+        !CBB_add_u16(cbb.get(), group_id) ||                       //
+        !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||  //
+        !key_share->Generate(&key_exchange) ||                     //
         !hs->key_shares.TryPushBack(std::move(key_share))) {
       return false;
     }
@@ -2267,7 +2274,12 @@
     return true;
   }
 
-  assert(!hs->key_share_bytes.empty());
+  // The caller may explicitly configure empty key shares to request a
+  // HelloRetryRequest.
+  assert(!hs->key_share_bytes.empty() ||
+         (hs->config->client_key_share_selections.has_value() &&
+          hs->config->client_key_share_selections->empty()));
+
   CBB contents, kse_bytes;
   if (!CBB_add_u16(out_compressible, TLSEXT_TYPE_key_share) ||
       !CBB_add_u16_length_prefixed(out_compressible, &contents) ||
diff --git a/ssl/internal.h b/ssl/internal.h
index 37eb54e..d72f7a5 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1770,9 +1770,7 @@
   UniquePtr<ERR_SAVE_STATE> error;
 
   // key_shares are the current key exchange instances, in preference order. Any
-  // members of this vector must be non-null. By default, no more than two are
-  // used, and the second is only used as a client if we believe that we should
-  // offer two key shares in a ClientHello.
+  // members of this vector must be non-null.
   InplaceVector<UniquePtr<SSLKeyShare>, kNumNamedGroups> key_shares;
 
   // transcript is the current handshake transcript.
@@ -2130,9 +2128,19 @@
 bool ssl_setup_extension_permutation(SSL_HANDSHAKE *hs);
 
 // ssl_setup_key_shares computes client key shares and saves them in |hs|. It
-// returns true on success and false on failure. If |override_group_id| is zero,
-// it offers the default groups, including GREASE. If it is non-zero, it offers
-// a single key share of the specified group.
+// returns true on success and false on failure. In order of precedence:
+//
+// - If |override_group_id| is non-zero, it offers a single key share of the
+//   specified group.
+//
+// - If key shares were previously specified by the caller via
+//   |SSL_set_client_key_shares|, those are used.
+//
+// - If |override_group_id| is zero and no selections were made by the caller
+//   via |SSL_set_client_key_shares|, it selects the first supported group and
+//   may select a second if at most one of the two is a post-quantum group.
+//
+// GREASE will be included if enabled, when |override_group_id| is zero.
 bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id);
 
 // ssl_setup_pake_shares computes the client PAKE shares and saves them in |hs|.
@@ -3275,7 +3283,19 @@
   // Trust anchor IDs to be requested in the trust_anchors extension.
   std::optional<Array<uint8_t>> requested_trust_anchors;
 
-  Array<uint16_t> supported_group_list;  // our list
+  // Our list of supported groups. If this list is modified, for a client,
+  // |client_key_share_selections| must be reset if the key shares are no longer
+  // a valid subsequence of the supported group list.
+  Array<uint16_t> supported_group_list;
+
+  // For a client, this may contain a subsequence of the group IDs in
+  // |suppported_group_list|, which gives the groups for which key shares should
+  // be sent in the client's key_share extension. This is non-nullopt iff
+  // |SSL_set_client_key_shares| was successfully called to configure key
+  // shares. If non-nullopt, these groups are in the same order as they appear
+  // in |supported_group_list|, and may not contain duplicates.
+  std::optional<InplaceVector<uint16_t, kNumNamedGroups>>
+      client_key_share_selections;
 
   // channel_id_private is the client's Channel ID private key, or null if
   // Channel ID should not be offered on this connection.
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 2e0db35..532744f 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1844,6 +1844,37 @@
   return check_no_duplicates(group_ids);
 }
 
+// validate_key_shares returns whether the `requested_key_shares` are free of
+// duplicates and are a (correctly ordered) subsequence of the supported
+// `groups`.
+static bool validate_key_shares(Span<const uint16_t> requested_key_shares,
+                                Span<const uint16_t> groups) {
+  if (!check_no_duplicates(requested_key_shares)) {
+    return false;
+  }
+  if (requested_key_shares.size() > groups.size()) {
+    return false;
+  }
+  size_t key_shares_idx = 0u, groups_idx = 0u;
+  while (key_shares_idx < requested_key_shares.size() &&
+         groups_idx < groups.size()) {
+    if (requested_key_shares[key_shares_idx] == groups[groups_idx++]) {
+      ++key_shares_idx;
+    }
+  }
+  return key_shares_idx == requested_key_shares.size();
+}
+
+static void clear_key_shares_if_invalid(SSL_CONFIG *config) {
+  if (!config->client_key_share_selections) {
+    return;
+  }
+  if (!validate_key_shares(*(config->client_key_share_selections),
+                           config->supported_group_list)) {
+    config->client_key_share_selections.reset();
+  }
+}
+
 int SSL_CTX_set1_group_ids(SSL_CTX *ctx, const uint16_t *group_ids,
                            size_t num_group_ids) {
   auto span = Span(group_ids, num_group_ids);
@@ -1862,8 +1893,12 @@
   if (span.empty()) {
     span = DefaultSupportedGroupIds();
   }
-  return check_group_ids(span) &&
-         ssl->config->supported_group_list.CopyFrom(span);
+  if (check_group_ids(span) &&
+      ssl->config->supported_group_list.CopyFrom(span)) {
+    clear_key_shares_if_invalid(ssl->config.get());
+    return 1;
+  }
+  return 0;
 }
 
 static bool ssl_nids_to_group_ids(Array<uint16_t> *out_group_ids,
@@ -1899,8 +1934,12 @@
   if (!ssl->config) {
     return 0;
   }
-  return ssl_nids_to_group_ids(&ssl->config->supported_group_list,
-                               Span(groups, num_groups));
+  if (ssl_nids_to_group_ids(&ssl->config->supported_group_list,
+                            Span(groups, num_groups))) {
+    clear_key_shares_if_invalid(ssl->config.get());
+    return 1;
+  }
+  return 0;
 }
 
 static bool ssl_str_to_group_ids(Array<uint16_t> *out_group_ids,
@@ -1951,7 +1990,11 @@
   if (!ssl->config) {
     return 0;
   }
-  return ssl_str_to_group_ids(&ssl->config->supported_group_list, groups);
+  if (ssl_str_to_group_ids(&ssl->config->supported_group_list, groups)) {
+    clear_key_shares_if_invalid(ssl->config.get());
+    return 1;
+  }
+  return 0;
 }
 
 uint16_t SSL_get_group_id(const SSL *ssl) {
@@ -1971,6 +2014,23 @@
   return ssl_group_id_to_nid(group_id);
 }
 
+int SSL_set1_client_key_shares(SSL *ssl, const uint16_t *group_ids,
+                               size_t num_group_ids) {
+  if (!ssl->config) {
+    return 0;
+  }
+  auto requested_key_shares = Span(group_ids, num_group_ids);
+  if (!validate_key_shares(requested_key_shares,
+                           ssl->config->supported_group_list)) {
+    return 0;
+  }
+
+  assert(requested_key_shares.size() <= kNumNamedGroups);
+  ssl->config->client_key_share_selections.emplace();
+  ssl->config->client_key_share_selections->CopyFrom(requested_key_shares);
+  return 1;
+}
+
 int SSL_CTX_set_tmp_dh(SSL_CTX *ctx, const DH *dh) { return 1; }
 
 int SSL_set_tmp_dh(SSL *ssl, const DH *dh) { return 1; }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 96edb61..1b34c14 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -722,6 +722,146 @@
   }
 }
 
+TEST(SSLTest, SetClientKeyShares) {
+  const struct {
+    const char *description;
+    std::vector<uint16_t> supported_groups;
+    std::vector<uint16_t> key_shares;
+    bool expected_success;
+  } kTests[] = {
+      {
+          "Empty key shares with default supported groups",
+          {},
+          {},
+          true,
+      },
+      {
+          "Empty key shares with custom supported groups",
+          {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+          {},
+          true,
+      },
+      {
+          "One key share matching default supported groups",
+          {},
+          {SSL_GROUP_X25519},
+          true,
+      },
+      {
+          "One key share matching custom supported groups",
+          {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+          {SSL_GROUP_X25519},
+          true,
+      },
+      {
+          "Key share not in supported default groups",
+          {},
+          {SSL_GROUP_MLKEM1024},
+          false,
+      },
+      {
+          "Key share not in supported custom groups",
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1},
+          {SSL_GROUP_X25519_MLKEM768},
+          false,
+      },
+      {
+          "Multiple key shares, in correct order",
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+          {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+          true,
+      },
+      {
+          "Multiple key shares, out of order",
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+          {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519},
+          false,
+      },
+      {
+          "More than two key shares",
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768,
+           SSL_GROUP_MLKEM1024},
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+          true,
+      },
+      {
+          "Key shares cover all supported groups",
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+          true,
+      },
+      {
+          "Multiple key shares, not all valid",
+          {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+          {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+          false,
+      },
+      {
+          "Key shares contain duplicates",
+          {},
+          {SSL_GROUP_X25519, SSL_GROUP_X25519},
+          false,
+      },
+  };
+
+  for (const auto &t : kTests) {
+    SCOPED_TRACE(t.description);
+    bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+    ASSERT_TRUE(ctx);
+    bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+    ASSERT_TRUE(ssl);
+    ASSERT_FALSE(ssl->config->client_key_share_selections.has_value());
+
+    ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), t.supported_groups.data(),
+                                   t.supported_groups.size()));
+    EXPECT_EQ(SSL_set1_client_key_shares(ssl.get(), t.key_shares.data(),
+                                         t.key_shares.size()),
+              t.expected_success);
+    if (t.expected_success) {
+      ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+      EXPECT_THAT(ssl->config->client_key_share_selections.value(),
+                  ElementsAreArray(t.key_shares));
+    }
+  }
+}
+
+// Test the behavior that modifying the SSL's supported groups results in
+// clearing the previously set client key shares, iff the supported groups
+// become incompatible with the key shares.
+TEST(SSLTest, ClientKeySharesResetAfterChangingGroups) {
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+  bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+  ASSERT_TRUE(ssl);
+  ASSERT_FALSE(ssl->config->client_key_share_selections.has_value());
+
+  // An initial groups list and key shares that are compatible.
+  const uint16_t kGroups1[] = {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519};
+  const uint16_t kKeyShares[] = {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519};
+  ASSERT_TRUE(
+      SSL_set1_group_ids(ssl.get(), kGroups1, std::size(kGroups1)));
+  ASSERT_TRUE(SSL_set1_client_key_shares(ssl.get(), kKeyShares,
+                                         std::size(kKeyShares)));
+  ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+  EXPECT_EQ(ssl->config->client_key_share_selections->size(), 2u);
+
+  // A new groups list that is still compatible with the previously set key
+  // shares.
+  const uint16_t kGroups2[] = {SSL_GROUP_MLKEM1024, SSL_GROUP_X25519_MLKEM768,
+                               SSL_GROUP_X25519};
+  ASSERT_TRUE(
+      SSL_set1_group_ids(ssl.get(), kGroups2, std::size(kGroups2)));
+  ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+  EXPECT_EQ(ssl->config->client_key_share_selections->size(), 2u);
+
+  // A new groups list that is no longer compatible with the previously set key
+  // shares.
+  const uint16_t kGroups3[] = {SSL_GROUP_MLKEM1024, SSL_GROUP_X25519};
+  ASSERT_TRUE(
+      SSL_set1_group_ids(ssl.get(), kGroups3, std::size(kGroups3)));
+  EXPECT_FALSE(ssl->config->client_key_share_selections.has_value());
+}
+
 // kOpenSSLSession is a serialized SSL_SESSION.
 static const char kOpenSSLSession[] =
     "MIIFqgIBAQICAwMEAsAvBCAG5Q1ndq4Yfmbeo1zwLkNRKmCXGdNgWvGT3cskV0yQ"
diff --git a/ssl/test/runner/pake_tests.go b/ssl/test/runner/pake_tests.go
index ddd2eb3..677a3a4 100644
--- a/ssl/test/runner/pake_tests.go
+++ b/ssl/test/runner/pake_tests.go
@@ -14,7 +14,10 @@
 
 package runner
 
-import "errors"
+import (
+	"errors"
+	"strconv"
+)
 
 func addPAKETests() {
 	spakeCredential := Credential{
@@ -231,17 +234,22 @@
 		shimCredentials: []*Credential{&spakeCredential},
 	})
 	testCases = append(testCases, testCase{
-		// A PAKE client will not offer key shares, so the client should
-		// reject a HelloRetryRequest requesting a different key share.
+		// A PAKE client will not offer key shares, even if explicitly configured,
+		// and the client should reject a HelloRetryRequest requesting a different
+		// key share.
 		name:     "PAKE-Client-HRRKeyShare",
 		testType: clientTest,
 		config: Config{
 			MinVersion: VersionTLS13,
 			Credential: &spakeCredential,
 			Bugs: ProtocolBugs{
+				ExpectedKeyShares:          []CurveID{},
 				SendHelloRetryRequestCurve: CurveX25519,
 			},
 		},
+		flags: []string{
+			"-key-shares", strconv.Itoa(int(CurveP256)),
+		},
 		shimCredentials:    []*Credential{&spakeCredential},
 		shouldFail:         true,
 		expectedError:      ":UNEXPECTED_EXTENSION:",
diff --git a/ssl/test/runner/tls13_tests.go b/ssl/test/runner/tls13_tests.go
index 2b08ec8..b9826eb 100644
--- a/ssl/test/runner/tls13_tests.go
+++ b/ssl/test/runner/tls13_tests.go
@@ -132,6 +132,175 @@
 	})
 
 	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveP256},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveP256)),
+			"-key-shares", strconv.Itoa(int(CurveP256)),
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-DefaultSupportedGroups-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveX25519},
+			},
+		},
+		flags: []string{
+			// Configure key shares without explicitly configuring supported groups.
+			"-key-shares", strconv.Itoa(int(CurveX25519)),
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-Multiple-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveX25519, CurveX25519MLKEM768},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveP256)),
+			// Predict the top 2 out of 3 supported curves.
+			"-key-shares", strconv.Itoa(int(CurveX25519)),
+			"-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+			"-expect-no-hrr",
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-MultipleNonContiguous-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveX25519, CurveP256},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveP256)),
+			// Predict the first and last of 3 supported curves.
+			"-key-shares", strconv.Itoa(int(CurveX25519)),
+			"-key-shares", strconv.Itoa(int(CurveP256)),
+			"-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+			"-expect-no-hrr",
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-NotMostPreferred-HRR-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveP256},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveP256)),
+			// Predict some curves that we support, not including the top one.
+			"-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-key-shares", strconv.Itoa(int(CurveP256)),
+			"-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+			// Check that we triggered a HelloRetryRequest.
+			"-expect-hrr",
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-PeerSelectedLaterKeyShare-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{CurveP256},
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveP256},
+				// Make the server select one of the groups with a key_share, but not
+				// the most preferred one.
+				SendCurve: CurveP256,
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveP256)),
+			"-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-key-shares", strconv.Itoa(int(CurveP256)),
+			"-expect-curve-id", strconv.Itoa(int(CurveP256)),
+			"-expect-no-hrr",
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-All-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{
+					CurveP256,
+					CurveP384,
+					CurveP521,
+					CurveX25519,
+					CurveX25519Kyber768,
+					CurveX25519MLKEM768,
+					CurveMLKEM1024,
+				},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveP256)),
+			"-curves", strconv.Itoa(int(CurveP384)),
+			"-curves", strconv.Itoa(int(CurveP521)),
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveX25519Kyber768)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-key-shares", strconv.Itoa(int(CurveP256)),
+			"-key-shares", strconv.Itoa(int(CurveP384)),
+			"-key-shares", strconv.Itoa(int(CurveP521)),
+			"-key-shares", strconv.Itoa(int(CurveX25519)),
+			"-key-shares", strconv.Itoa(int(CurveX25519Kyber768)),
+			"-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-key-shares", strconv.Itoa(int(CurveMLKEM1024)),
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "CustomKeyShares-Empty-HRR-TLS13",
+		config: Config{
+			MinVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{},
+			},
+		},
+		flags: []string{
+			"-no-key-shares",
+			"-expect-hrr",
+		},
+	})
+
+	testCases = append(testCases, testCase{
 		testType: serverTest,
 		name:     "SkipEarlyData-TLS13",
 		config: Config{
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index f973d9e..0cfdc37 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -53,12 +53,17 @@
 template <typename Config>
 struct Flag {
   const char *name;
+  // has_param, if true, causes the parser to look for a param value following
+  // this flag's name.
   bool has_param;
   // skip_handshaker, if true, causes this flag to be skipped when
   // forwarding flags to the handshaker. This should be used with flags
   // that only impact connecting to the runner.
   bool skip_handshaker;
-  // If |has_param| is false, |param| will be nullptr.
+  // set_param is called after parsing to interpret and set the result on
+  // `config`. If `has_param` is false for this flag, `param` will be nullptr.
+  // This function should return whether the param value (or lack thereof) was
+  // valid for this flag.
   std::function<bool(Config *config, const char *param)> set_param;
 };
 
@@ -168,6 +173,39 @@
                       }};
 }
 
+// Defines a flag which adds an integer param value to an optional vector of
+// integers.
+template <typename Config, typename T>
+Flag<Config> OptionalIntVectorFlag(const char *name,
+                                   std::optional<std::vector<T>> Config::*field,
+                                   bool skip_handshaker = false) {
+  return Flag<Config>{name, true, skip_handshaker,
+                      [=](Config *config, const char *param) -> bool {
+                        if (!(config->*field)) {
+                          (config->*field).emplace();
+                        }
+                        T value;
+                        if (!StringToInt(&value, param)) {
+                          return false;
+                        }
+                        (config->*field)->push_back(value);
+                        return true;
+                      }};
+}
+
+// Defines a flag which resets a std::optional field to its default constructed
+// value.
+template <typename Config, typename T>
+Flag<Config> OptionalDefaultInitFlag(const char *name,
+                                     std::optional<T> Config::*field,
+                                     bool skip_handshaker = false) {
+  return Flag<Config>{name, false, skip_handshaker,
+                      [=](Config *config, const char *) -> bool {
+                        (config->*field).emplace();
+                        return true;
+                      }};
+}
+
 template <typename Config>
 Flag<Config> StringFlag(const char *name, std::string Config::*field,
                         bool skip_handshaker = false) {
@@ -327,6 +365,8 @@
         IntVectorFlag("-expect-peer-verify-pref",
                       &TestConfig::expect_peer_verify_prefs),
         IntVectorFlag("-curves", &TestConfig::curves),
+        OptionalIntVectorFlag("-key-shares", &TestConfig::key_shares),
+        OptionalDefaultInitFlag("-no-key-shares", &TestConfig::key_shares),
         StringFlag("-trust-cert", &TestConfig::trust_cert),
         StringFlag("-expect-server-name", &TestConfig::expect_server_name),
         BoolFlag("-enable-ech-grease", &TestConfig::enable_ech_grease),
@@ -679,6 +719,9 @@
     if (!skip) {
       if (out != nullptr) {
         if (!flag->set_param(out, param)) {
+          if (!param) {
+            param = "(no parameter)";
+          }
           fprintf(stderr, "Invalid parameter for %s: %s\n", name, param);
           return false;
         }
@@ -687,6 +730,9 @@
         if (!flag->set_param(out_initial, param) ||
             !flag->set_param(out_resume, param) ||
             !flag->set_param(out_retry, param)) {
+          if (!param) {
+            param = "(no parameter)";
+          }
           fprintf(stderr, "Invalid parameter for %s: %s\n", name, param);
           return false;
         }
@@ -2481,6 +2527,11 @@
       !SSL_set1_group_ids(ssl.get(), curves.data(), curves.size())) {
     return nullptr;
   }
+  if (key_shares.has_value() &&
+      !SSL_set1_client_key_shares(ssl.get(), key_shares->data(),
+                                  key_shares->size())) {
+    return nullptr;
+  }
   if (initial_timeout_duration_ms > 0) {
     DTLSv1_set_initial_timeout_duration(ssl.get(), initial_timeout_duration_ms);
   }
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index de2f5c5..3b2f4e1 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -65,6 +65,7 @@
   std::vector<uint16_t> verify_prefs;
   std::vector<uint16_t> expect_peer_verify_prefs;
   std::vector<uint16_t> curves;
+  std::optional<std::vector<uint16_t>> key_shares;
   std::string key_file;
   std::string cert_file;
   std::string trust_cert;
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index dc417b9..a664321 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -253,7 +253,11 @@
   // The ECH extension, if present, was already parsed by
   // |check_ech_confirmation|.
   SSLExtension cookie(TLSEXT_TYPE_cookie),
-      key_share(TLSEXT_TYPE_key_share, !hs->key_share_bytes.empty()),
+      // If offering PAKE, we won't send key_share extensions and we should
+      // reject key_share from the peer. Otherwise, it is valid to have sent an
+      // empty key_share extension, and expect the HelloRetryRequest to contain
+      // a key_share.
+      key_share(TLSEXT_TYPE_key_share, !hs->pake_prover),
       supported_versions(TLSEXT_TYPE_supported_versions),
       ech_unused(TLSEXT_TYPE_encrypted_client_hello,
                  hs->selected_ech_config || hs->config->ech_grease_enabled);
@@ -286,8 +290,6 @@
   }
 
   if (key_share.present) {
-    // If offering PAKE, we won't send key_share extensions, in which case we
-    // would have rejected key_share from the peer.
     assert(!hs->pake_prover);
 
     uint16_t group_id;