Add API for caller to hint server's preferred key shares

This change introduces a new mechanism for clients to provide a "hint"
about which key exchange groups the server is likely to support. This is
meant to be used for draft-ietf-tls-key-share-prediction.

We will try to make a prediction based on the server's hint, and if a
group was predicted, it is sent as the key_share, overriding any
explicit key shares configured by the caller.

Bug: 437414371
Change-Id: Ie98d2a63e5bfb4f10683ca4f1f77e5439b39e256
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81588
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Lily Chen <chlily@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index d2d6066..b66ad72 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -2672,6 +2672,27 @@
                                               const uint16_t *group_ids,
                                               size_t num_group_ids);
 
+// SSL_set1_server_supported_groups_hint, when |ssl| is a client, indicates that
+// the server is likely to support groups listed in |server_groups|, in order of
+// decreasing server preference. This function returns one on success and zero
+// on error. This may be used when receiving a server hint, such as described in
+// draft-ietf-tls-key-share-prediction.
+//
+// If called, |ssl| will try to predict the server's selected named group based
+// on |ssl|'s local preferences and |server_groups|. If it predicts a group, it
+// will then send an initial ClientHello with key_share extension containing
+// only this prediction. In this case, the prediction will supersede any
+// configuration from |SSL_set1_client_key_shares|. This is a convenience
+// function so that callers do not need to process the server preference list
+// themselves.
+//
+// Groups listed in |server_groups| should be identified by their TLS group IDs,
+// such as the |SSL_GROUP_*| constants. A server may implement groups not known
+// to BoringSSL, so |server_groups| may contain unrecognized group IDs. If so,
+// this function will ignore them.
+OPENSSL_EXPORT int SSL_set1_server_supported_groups_hint(
+    SSL *ssl, const uint16_t *server_groups, size_t num_server_groups);
+
 
 // Certificate verification.
 //
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index b6043e4..0820d92 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -2204,18 +2204,38 @@
     return false;
   }
 
-  InplaceVector<uint16_t, 2> default_key_shares;
+  std::optional<Span<const uint16_t>> selected_key_shares;
 
-  Span<const uint16_t> selected_key_shares;
-
-  // Determine the key shares to send.
+  // This is used for HelloRetryRequest, where we have already been given the
+  // correct group to use.
   if (override_group_id != 0) {
     assert(std::find(supported_group_list.begin(), supported_group_list.end(),
                      override_group_id) != supported_group_list.end());
-    selected_key_shares = Span(&override_group_id, 1u);
-  } else if (ssl->config->client_key_share_selections.has_value()) {
-    selected_key_shares = *(ssl->config->client_key_share_selections);
-  } else {
+    selected_key_shares.emplace(&override_group_id, 1u);
+  }
+
+  // First try to predict the most preferred (by our preference order) group
+  // that was hinted by the server.
+  const auto &server_hint_list = ssl->config->server_supported_groups_hint;
+  if (!selected_key_shares.has_value() && !server_hint_list.empty()) {
+    for (const uint16_t &supported_group : supported_group_list) {
+      if (std::find(server_hint_list.begin(), server_hint_list.end(),
+                    supported_group) != server_hint_list.end()) {
+        selected_key_shares.emplace(&supported_group, 1u);
+        break;
+      }
+    }
+  }
+
+  // Otherwise, try to use explicitly configured key shares from the caller.
+  if (!selected_key_shares.has_value() &&
+      ssl->config->client_key_share_selections.has_value()) {
+    selected_key_shares.emplace(*ssl->config->client_key_share_selections);
+  }
+
+  // Run the default selection if we don't have anything better.
+  InplaceVector<uint16_t, 2> default_key_shares;
+  if (!selected_key_shares.has_value()) {
     // By default, predict the most preferred group.
     if (!default_key_shares.TryPushBack(supported_group_list[0])) {
       return false;
@@ -2233,9 +2253,11 @@
       assert(default_key_shares[1] != default_key_shares[0]);
       break;
     }
-    selected_key_shares = default_key_shares;
+    selected_key_shares.emplace(default_key_shares);
   }
 
+  assert(selected_key_shares.has_value());
+
   bssl::ScopedCBB cbb;
   if (!CBB_init(cbb.get(), 64)) {
     return false;
@@ -2251,7 +2273,7 @@
   }
 
   CBB key_exchange;
-  for (const uint16_t group_id : selected_key_shares) {
+  for (const uint16_t group_id : *selected_key_shares) {
     UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
     if (key_share == nullptr ||                                    //
         !CBB_add_u16(cbb.get(), group_id) ||                       //
diff --git a/ssl/internal.h b/ssl/internal.h
index 8486f31..cd5b67d 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2133,12 +2133,15 @@
 // - If |override_group_id| is non-zero, it offers a single key share of the
 //   specified group.
 //
-// - If key shares were previously specified by the caller via
-//   |SSL_set_client_key_shares|, those are used.
+// - If a group can be predicted on the basis of a server hint set via
+//   |SSL_set1_server_supported_groups_hint|, a single key share of that group
+//   is sent.
 //
-// - If |override_group_id| is zero and no selections were made by the caller
-//   via |SSL_set_client_key_shares|, it selects the first supported group and
-//   may select a second if at most one of the two is a post-quantum group.
+// - If any number of key shares (including zero) were previously specified by
+//   the caller via |SSL_set1_client_key_shares|, those are used.
+//
+// - Otherwise, it selects the first supported group and may select a second if
+//   at most one of the two is a post-quantum group.
 //
 // GREASE will be included if enabled, when |override_group_id| is zero.
 bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id);
@@ -3291,12 +3294,16 @@
   // For a client, this may contain a subsequence of the group IDs in
   // |suppported_group_list|, which gives the groups for which key shares should
   // be sent in the client's key_share extension. This is non-nullopt iff
-  // |SSL_set_client_key_shares| was successfully called to configure key
+  // |SSL_set1_client_key_shares| was successfully called to configure key
   // shares. If non-nullopt, these groups are in the same order as they appear
   // in |supported_group_list|, and may not contain duplicates.
   std::optional<InplaceVector<uint16_t, kNumNamedGroups>>
       client_key_share_selections;
 
+  // For a client, this contains a list of groups believed to be supported by
+  // the server, in server preference order.
+  Array<uint16_t> server_supported_groups_hint;
+
   // channel_id_private is the client's Channel ID private key, or null if
   // Channel ID should not be offered on this connection.
   UniquePtr<EVP_PKEY> channel_id_private;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 532744f..bbe428f 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -2031,6 +2031,16 @@
   return 1;
 }
 
+int SSL_set1_server_supported_groups_hint(SSL *ssl,
+                                          const uint16_t *server_groups,
+                                          size_t num_server_groups) {
+  if (!ssl->config) {
+    return 0;
+  }
+  auto span = Span(server_groups, num_server_groups);
+  return ssl->config->server_supported_groups_hint.CopyFrom(span);
+}
+
 int SSL_CTX_set_tmp_dh(SSL_CTX *ctx, const DH *dh) { return 1; }
 
 int SSL_set_tmp_dh(SSL *ssl, const DH *dh) { return 1; }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index f34a531..efb4bce 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -862,6 +862,139 @@
   EXPECT_FALSE(ssl->config->client_key_share_selections.has_value());
 }
 
+TEST(SSLTest, ServerSupportedGroupsHint) {
+  // List of client's supported groups used for these test cases.
+  const uint16_t kSupportedGroups[] = {
+      SSL_GROUP_X25519_MLKEM768,  //
+      SSL_GROUP_MLKEM1024,        //
+      SSL_GROUP_X25519,           //
+      SSL_GROUP_SECP256R1,        //
+  };
+
+  // By default, the first post-quantum and first classical groups are chosen.
+  std::vector<uint16_t> kDefaultKeyShares = {SSL_GROUP_X25519_MLKEM768,
+                                             SSL_GROUP_X25519};
+
+  const struct {
+    const char *description;
+    std::vector<uint16_t> server_hinted_groups;
+    std::vector<uint16_t> expected_key_shares;
+  } kTests[] = {
+      {
+          "Empty hint (defaults chosen)",
+          {},
+          kDefaultKeyShares,
+      },
+      {
+          "Hint one group, supported by client",
+          {SSL_GROUP_SECP256R1},
+          {SSL_GROUP_SECP256R1},
+      },
+      {
+          "Hint one group, not supported by client",
+          {SSL_GROUP_X25519_KYBER768_DRAFT00},
+          kDefaultKeyShares,
+      },
+      {
+          "Hint two groups, both supported by client in same order",
+          {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_MLKEM1024},
+          {SSL_GROUP_X25519_MLKEM768},
+      },
+      {
+          "Hint two groups, both supported by client in different order",
+          {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+          // The key share will be determined by the client's preference order,
+          // not the server hint's order.
+          {SSL_GROUP_X25519_MLKEM768},
+      },
+      {
+          "Hint two groups, supported and non-supported by client",
+          {SSL_GROUP_MLKEM1024, SSL_GROUP_SECP384R1},
+          {SSL_GROUP_MLKEM1024},
+      },
+      {
+          "Hint two groups, non-supported and supported by client",
+          {SSL_GROUP_SECP384R1, SSL_GROUP_MLKEM1024},
+          {SSL_GROUP_MLKEM1024},
+      },
+      {
+          "Hint unrecognized group",
+          {0x1234},
+          kDefaultKeyShares,
+      },
+  };
+
+  for (const auto &t : kTests) {
+    SCOPED_TRACE(t.description);
+    bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+    ASSERT_TRUE(ctx);
+    ASSERT_TRUE(SSL_CTX_set1_group_ids(ctx.get(), kSupportedGroups,
+                                       std::size(kSupportedGroups)));
+    bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+    ASSERT_TRUE(ssl);
+    ASSERT_TRUE(SSL_set1_server_supported_groups_hint(
+        ssl.get(), t.server_hinted_groups.data(),
+        t.server_hinted_groups.size()));
+    ASSERT_TRUE(SSL_connect(ssl.get()));
+
+    std::vector<uint16_t> key_shares;
+    for (const auto &key_share : ssl->s3->hs->key_shares) {
+      key_shares.push_back(key_share->GroupID());
+    }
+    EXPECT_THAT(key_shares, ElementsAreArray(t.expected_key_shares));
+  }
+}
+
+TEST(SSLTest, ServerHintOverridesClientKeyShareSelections) {
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+  bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+  ASSERT_TRUE(ssl);
+
+  const uint16_t kGroups[] = {SSL_GROUP_SECP256R1, SSL_GROUP_X25519};
+  ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), kGroups, std::size(kGroups)));
+
+  const uint16_t kKeyShares[] = {SSL_GROUP_SECP256R1};
+  ASSERT_TRUE(
+      SSL_set1_client_key_shares(ssl.get(), kKeyShares, std::size(kKeyShares)));
+  ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+  EXPECT_THAT(ssl->config->client_key_share_selections.value(),
+              ElementsAreArray(kKeyShares));
+  const uint16_t kServerHint[] = {SSL_GROUP_X25519};
+  ASSERT_TRUE(SSL_set1_server_supported_groups_hint(ssl.get(), kServerHint,
+                                                    std::size(kServerHint)));
+  EXPECT_THAT(ssl->config->server_supported_groups_hint,
+              ElementsAreArray(kServerHint));
+
+  // The group predicted based on the server hint should win.
+  ASSERT_TRUE(SSL_connect(ssl.get()));
+  ASSERT_EQ(ssl->s3->hs->key_shares.size(), 1u);
+  EXPECT_EQ(kServerHint[0], ssl->s3->hs->key_shares[0]->GroupID());
+}
+
+TEST(SSLTest, ServerHintOverridesEmptyClientKeyShareSelections) {
+  bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(ctx);
+  bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+  ASSERT_TRUE(ssl);
+
+  const uint16_t kGroups[] = {SSL_GROUP_SECP256R1, SSL_GROUP_X25519};
+  ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), kGroups, std::size(kGroups)));
+
+  ASSERT_TRUE(SSL_set1_client_key_shares(ssl.get(), nullptr, 0));
+  EXPECT_TRUE(ssl->config->client_key_share_selections->empty());
+  const uint16_t kServerHint[] = {SSL_GROUP_X25519};
+  ASSERT_TRUE(SSL_set1_server_supported_groups_hint(ssl.get(), kServerHint,
+                                                    std::size(kServerHint)));
+  EXPECT_THAT(ssl->config->server_supported_groups_hint,
+              ElementsAreArray(kServerHint));
+
+  // The group predicted based on the server hint should win.
+  ASSERT_TRUE(SSL_connect(ssl.get()));
+  ASSERT_EQ(ssl->s3->hs->key_shares.size(), 1u);
+  EXPECT_EQ(kServerHint[0], ssl->s3->hs->key_shares[0]->GroupID());
+}
+
 // kOpenSSLSession is a serialized SSL_SESSION.
 static const char kOpenSSLSession[] =
     "MIIFqgIBAQICAwMEAsAvBCAG5Q1ndq4Yfmbeo1zwLkNRKmCXGdNgWvGT3cskV0yQ"
diff --git a/ssl/test/runner/tls13_tests.go b/ssl/test/runner/tls13_tests.go
index b9826eb..e4af165 100644
--- a/ssl/test/runner/tls13_tests.go
+++ b/ssl/test/runner/tls13_tests.go
@@ -301,6 +301,137 @@
 	})
 
 	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "KeyShareWithServerHint-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{CurveX25519MLKEM768, CurveX25519},
+			Bugs: ProtocolBugs{
+				// The prediction is made based on the client's preference order.
+				ExpectedKeyShares: []CurveID{CurveX25519},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveX25519)),
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "KeyShareWithServerHint-UnrecognizedGroupIgnored-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{bogusCurve, CurveX25519},
+			Bugs: ProtocolBugs{
+				ExpectedKeyShares: []CurveID{CurveX25519},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-server-supported-groups-hint", strconv.Itoa(int(bogusCurve)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveX25519)),
+			"-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+			"-expect-no-hrr",
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "KeyShareWithServerHint-OverridesExplicitKeyShare-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{CurveX25519MLKEM768, CurveMLKEM1024},
+			Bugs: ProtocolBugs{
+				// The group predicted from the server hint is used over explicitly
+				// configured key shares.
+				ExpectedKeyShares: []CurveID{CurveX25519MLKEM768},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-key-shares", strconv.Itoa(int(CurveMLKEM1024)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveMLKEM1024)),
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "KeyShareWithServerHint-OverridesExplicitEmptyKeyShare-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{CurveX25519MLKEM768, CurveMLKEM1024},
+			Bugs: ProtocolBugs{
+				// The group predicted from the server hint is used over explicitly
+				// configured empty key shares.
+				ExpectedKeyShares: []CurveID{CurveX25519MLKEM768},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-no-key-shares",
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveMLKEM1024)),
+		},
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "KeyShareWithServerHint-FallBackToExplicitKeyShare-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{CurveP256},
+			Bugs: ProtocolBugs{
+				// No group was predicted from the server hint, so the explicitly
+				// configured key shares are used.
+				ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveX25519},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-key-shares", strconv.Itoa(int(CurveX25519)),
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveP256)),
+		},
+		shouldFail:    true,
+		expectedError: ":HANDSHAKE_FAILURE_ON_CLIENT_HELLO:",
+	})
+
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "KeyShareWithServerHint-FallBackToExplicitEmptyKeyShare-TLS13",
+		config: Config{
+			MinVersion:       VersionTLS13,
+			CurvePreferences: []CurveID{CurveP256},
+			Bugs: ProtocolBugs{
+				// No group was predicted from the server hint, so the explicitly
+				// configured key shares are used (in this case, empty to request HRR).
+				ExpectedKeyShares: []CurveID{},
+			},
+		},
+		flags: []string{
+			"-curves", strconv.Itoa(int(CurveMLKEM1024)),
+			"-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+			"-curves", strconv.Itoa(int(CurveX25519)),
+			"-no-key-shares",
+			"-server-supported-groups-hint", strconv.Itoa(int(CurveP256)),
+			"-expect-hrr",
+		},
+		shouldFail:    true,
+		expectedError: ":HANDSHAKE_FAILURE_ON_CLIENT_HELLO:",
+	})
+
+	testCases = append(testCases, testCase{
 		testType: serverTest,
 		name:     "SkipEarlyData-TLS13",
 		config: Config{
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 0cfdc37..bce633a 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -367,6 +367,8 @@
         IntVectorFlag("-curves", &TestConfig::curves),
         OptionalIntVectorFlag("-key-shares", &TestConfig::key_shares),
         OptionalDefaultInitFlag("-no-key-shares", &TestConfig::key_shares),
+        IntVectorFlag("-server-supported-groups-hint",
+                      &TestConfig::server_supported_groups_hint),
         StringFlag("-trust-cert", &TestConfig::trust_cert),
         StringFlag("-expect-server-name", &TestConfig::expect_server_name),
         BoolFlag("-enable-ech-grease", &TestConfig::enable_ech_grease),
@@ -2532,6 +2534,12 @@
                                   key_shares->size())) {
     return nullptr;
   }
+  if (!server_supported_groups_hint.empty() &&
+      !SSL_set1_server_supported_groups_hint(
+          ssl.get(), server_supported_groups_hint.data(),
+          server_supported_groups_hint.size())) {
+    return nullptr;
+  }
   if (initial_timeout_duration_ms > 0) {
     DTLSv1_set_initial_timeout_duration(ssl.get(), initial_timeout_duration_ms);
   }
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 3b2f4e1..9745de0 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -66,6 +66,7 @@
   std::vector<uint16_t> expect_peer_verify_prefs;
   std::vector<uint16_t> curves;
   std::optional<std::vector<uint16_t>> key_shares;
+  std::vector<uint16_t> server_supported_groups_hint;
   std::string key_file;
   std::string cert_file;
   std::string trust_cert;