Add API for caller to hint server's preferred key shares
This change introduces a new mechanism for clients to provide a "hint"
about which key exchange groups the server is likely to support. This is
meant to be used for draft-ietf-tls-key-share-prediction.
We will try to make a prediction based on the server's hint, and if a
group was predicted, it is sent as the key_share, overriding any
explicit key shares configured by the caller.
Bug: 437414371
Change-Id: Ie98d2a63e5bfb4f10683ca4f1f77e5439b39e256
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81588
Reviewed-by: David Benjamin <davidben@google.com>
Commit-Queue: Lily Chen <chlily@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index d2d6066..b66ad72 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -2672,6 +2672,27 @@
const uint16_t *group_ids,
size_t num_group_ids);
+// SSL_set1_server_supported_groups_hint, when |ssl| is a client, indicates that
+// the server is likely to support groups listed in |server_groups|, in order of
+// decreasing server preference. This function returns one on success and zero
+// on error. This may be used when receiving a server hint, such as described in
+// draft-ietf-tls-key-share-prediction.
+//
+// If called, |ssl| will try to predict the server's selected named group based
+// on |ssl|'s local preferences and |server_groups|. If it predicts a group, it
+// will then send an initial ClientHello with key_share extension containing
+// only this prediction. In this case, the prediction will supersede any
+// configuration from |SSL_set1_client_key_shares|. This is a convenience
+// function so that callers do not need to process the server preference list
+// themselves.
+//
+// Groups listed in |server_groups| should be identified by their TLS group IDs,
+// such as the |SSL_GROUP_*| constants. A server may implement groups not known
+// to BoringSSL, so |server_groups| may contain unrecognized group IDs. If so,
+// this function will ignore them.
+OPENSSL_EXPORT int SSL_set1_server_supported_groups_hint(
+ SSL *ssl, const uint16_t *server_groups, size_t num_server_groups);
+
// Certificate verification.
//
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index b6043e4..0820d92 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -2204,18 +2204,38 @@
return false;
}
- InplaceVector<uint16_t, 2> default_key_shares;
+ std::optional<Span<const uint16_t>> selected_key_shares;
- Span<const uint16_t> selected_key_shares;
-
- // Determine the key shares to send.
+ // This is used for HelloRetryRequest, where we have already been given the
+ // correct group to use.
if (override_group_id != 0) {
assert(std::find(supported_group_list.begin(), supported_group_list.end(),
override_group_id) != supported_group_list.end());
- selected_key_shares = Span(&override_group_id, 1u);
- } else if (ssl->config->client_key_share_selections.has_value()) {
- selected_key_shares = *(ssl->config->client_key_share_selections);
- } else {
+ selected_key_shares.emplace(&override_group_id, 1u);
+ }
+
+ // First try to predict the most preferred (by our preference order) group
+ // that was hinted by the server.
+ const auto &server_hint_list = ssl->config->server_supported_groups_hint;
+ if (!selected_key_shares.has_value() && !server_hint_list.empty()) {
+ for (const uint16_t &supported_group : supported_group_list) {
+ if (std::find(server_hint_list.begin(), server_hint_list.end(),
+ supported_group) != server_hint_list.end()) {
+ selected_key_shares.emplace(&supported_group, 1u);
+ break;
+ }
+ }
+ }
+
+ // Otherwise, try to use explicitly configured key shares from the caller.
+ if (!selected_key_shares.has_value() &&
+ ssl->config->client_key_share_selections.has_value()) {
+ selected_key_shares.emplace(*ssl->config->client_key_share_selections);
+ }
+
+ // Run the default selection if we don't have anything better.
+ InplaceVector<uint16_t, 2> default_key_shares;
+ if (!selected_key_shares.has_value()) {
// By default, predict the most preferred group.
if (!default_key_shares.TryPushBack(supported_group_list[0])) {
return false;
@@ -2233,9 +2253,11 @@
assert(default_key_shares[1] != default_key_shares[0]);
break;
}
- selected_key_shares = default_key_shares;
+ selected_key_shares.emplace(default_key_shares);
}
+ assert(selected_key_shares.has_value());
+
bssl::ScopedCBB cbb;
if (!CBB_init(cbb.get(), 64)) {
return false;
@@ -2251,7 +2273,7 @@
}
CBB key_exchange;
- for (const uint16_t group_id : selected_key_shares) {
+ for (const uint16_t group_id : *selected_key_shares) {
UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
if (key_share == nullptr || //
!CBB_add_u16(cbb.get(), group_id) || //
diff --git a/ssl/internal.h b/ssl/internal.h
index 8486f31..cd5b67d 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2133,12 +2133,15 @@
// - If |override_group_id| is non-zero, it offers a single key share of the
// specified group.
//
-// - If key shares were previously specified by the caller via
-// |SSL_set_client_key_shares|, those are used.
+// - If a group can be predicted on the basis of a server hint set via
+// |SSL_set1_server_supported_groups_hint|, a single key share of that group
+// is sent.
//
-// - If |override_group_id| is zero and no selections were made by the caller
-// via |SSL_set_client_key_shares|, it selects the first supported group and
-// may select a second if at most one of the two is a post-quantum group.
+// - If any number of key shares (including zero) were previously specified by
+// the caller via |SSL_set1_client_key_shares|, those are used.
+//
+// - Otherwise, it selects the first supported group and may select a second if
+// at most one of the two is a post-quantum group.
//
// GREASE will be included if enabled, when |override_group_id| is zero.
bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id);
@@ -3291,12 +3294,16 @@
// For a client, this may contain a subsequence of the group IDs in
// |suppported_group_list|, which gives the groups for which key shares should
// be sent in the client's key_share extension. This is non-nullopt iff
- // |SSL_set_client_key_shares| was successfully called to configure key
+ // |SSL_set1_client_key_shares| was successfully called to configure key
// shares. If non-nullopt, these groups are in the same order as they appear
// in |supported_group_list|, and may not contain duplicates.
std::optional<InplaceVector<uint16_t, kNumNamedGroups>>
client_key_share_selections;
+ // For a client, this contains a list of groups believed to be supported by
+ // the server, in server preference order.
+ Array<uint16_t> server_supported_groups_hint;
+
// channel_id_private is the client's Channel ID private key, or null if
// Channel ID should not be offered on this connection.
UniquePtr<EVP_PKEY> channel_id_private;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 532744f..bbe428f 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -2031,6 +2031,16 @@
return 1;
}
+int SSL_set1_server_supported_groups_hint(SSL *ssl,
+ const uint16_t *server_groups,
+ size_t num_server_groups) {
+ if (!ssl->config) {
+ return 0;
+ }
+ auto span = Span(server_groups, num_server_groups);
+ return ssl->config->server_supported_groups_hint.CopyFrom(span);
+}
+
int SSL_CTX_set_tmp_dh(SSL_CTX *ctx, const DH *dh) { return 1; }
int SSL_set_tmp_dh(SSL *ssl, const DH *dh) { return 1; }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index f34a531..efb4bce 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -862,6 +862,139 @@
EXPECT_FALSE(ssl->config->client_key_share_selections.has_value());
}
+TEST(SSLTest, ServerSupportedGroupsHint) {
+ // List of client's supported groups used for these test cases.
+ const uint16_t kSupportedGroups[] = {
+ SSL_GROUP_X25519_MLKEM768, //
+ SSL_GROUP_MLKEM1024, //
+ SSL_GROUP_X25519, //
+ SSL_GROUP_SECP256R1, //
+ };
+
+ // By default, the first post-quantum and first classical groups are chosen.
+ std::vector<uint16_t> kDefaultKeyShares = {SSL_GROUP_X25519_MLKEM768,
+ SSL_GROUP_X25519};
+
+ const struct {
+ const char *description;
+ std::vector<uint16_t> server_hinted_groups;
+ std::vector<uint16_t> expected_key_shares;
+ } kTests[] = {
+ {
+ "Empty hint (defaults chosen)",
+ {},
+ kDefaultKeyShares,
+ },
+ {
+ "Hint one group, supported by client",
+ {SSL_GROUP_SECP256R1},
+ {SSL_GROUP_SECP256R1},
+ },
+ {
+ "Hint one group, not supported by client",
+ {SSL_GROUP_X25519_KYBER768_DRAFT00},
+ kDefaultKeyShares,
+ },
+ {
+ "Hint two groups, both supported by client in same order",
+ {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_MLKEM1024},
+ {SSL_GROUP_X25519_MLKEM768},
+ },
+ {
+ "Hint two groups, both supported by client in different order",
+ {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+ // The key share will be determined by the client's preference order,
+ // not the server hint's order.
+ {SSL_GROUP_X25519_MLKEM768},
+ },
+ {
+ "Hint two groups, supported and non-supported by client",
+ {SSL_GROUP_MLKEM1024, SSL_GROUP_SECP384R1},
+ {SSL_GROUP_MLKEM1024},
+ },
+ {
+ "Hint two groups, non-supported and supported by client",
+ {SSL_GROUP_SECP384R1, SSL_GROUP_MLKEM1024},
+ {SSL_GROUP_MLKEM1024},
+ },
+ {
+ "Hint unrecognized group",
+ {0x1234},
+ kDefaultKeyShares,
+ },
+ };
+
+ for (const auto &t : kTests) {
+ SCOPED_TRACE(t.description);
+ bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+ ASSERT_TRUE(ctx);
+ ASSERT_TRUE(SSL_CTX_set1_group_ids(ctx.get(), kSupportedGroups,
+ std::size(kSupportedGroups)));
+ bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+ ASSERT_TRUE(ssl);
+ ASSERT_TRUE(SSL_set1_server_supported_groups_hint(
+ ssl.get(), t.server_hinted_groups.data(),
+ t.server_hinted_groups.size()));
+ ASSERT_TRUE(SSL_connect(ssl.get()));
+
+ std::vector<uint16_t> key_shares;
+ for (const auto &key_share : ssl->s3->hs->key_shares) {
+ key_shares.push_back(key_share->GroupID());
+ }
+ EXPECT_THAT(key_shares, ElementsAreArray(t.expected_key_shares));
+ }
+}
+
+TEST(SSLTest, ServerHintOverridesClientKeyShareSelections) {
+ bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+ ASSERT_TRUE(ctx);
+ bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+ ASSERT_TRUE(ssl);
+
+ const uint16_t kGroups[] = {SSL_GROUP_SECP256R1, SSL_GROUP_X25519};
+ ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), kGroups, std::size(kGroups)));
+
+ const uint16_t kKeyShares[] = {SSL_GROUP_SECP256R1};
+ ASSERT_TRUE(
+ SSL_set1_client_key_shares(ssl.get(), kKeyShares, std::size(kKeyShares)));
+ ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+ EXPECT_THAT(ssl->config->client_key_share_selections.value(),
+ ElementsAreArray(kKeyShares));
+ const uint16_t kServerHint[] = {SSL_GROUP_X25519};
+ ASSERT_TRUE(SSL_set1_server_supported_groups_hint(ssl.get(), kServerHint,
+ std::size(kServerHint)));
+ EXPECT_THAT(ssl->config->server_supported_groups_hint,
+ ElementsAreArray(kServerHint));
+
+ // The group predicted based on the server hint should win.
+ ASSERT_TRUE(SSL_connect(ssl.get()));
+ ASSERT_EQ(ssl->s3->hs->key_shares.size(), 1u);
+ EXPECT_EQ(kServerHint[0], ssl->s3->hs->key_shares[0]->GroupID());
+}
+
+TEST(SSLTest, ServerHintOverridesEmptyClientKeyShareSelections) {
+ bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+ ASSERT_TRUE(ctx);
+ bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+ ASSERT_TRUE(ssl);
+
+ const uint16_t kGroups[] = {SSL_GROUP_SECP256R1, SSL_GROUP_X25519};
+ ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), kGroups, std::size(kGroups)));
+
+ ASSERT_TRUE(SSL_set1_client_key_shares(ssl.get(), nullptr, 0));
+ EXPECT_TRUE(ssl->config->client_key_share_selections->empty());
+ const uint16_t kServerHint[] = {SSL_GROUP_X25519};
+ ASSERT_TRUE(SSL_set1_server_supported_groups_hint(ssl.get(), kServerHint,
+ std::size(kServerHint)));
+ EXPECT_THAT(ssl->config->server_supported_groups_hint,
+ ElementsAreArray(kServerHint));
+
+ // The group predicted based on the server hint should win.
+ ASSERT_TRUE(SSL_connect(ssl.get()));
+ ASSERT_EQ(ssl->s3->hs->key_shares.size(), 1u);
+ EXPECT_EQ(kServerHint[0], ssl->s3->hs->key_shares[0]->GroupID());
+}
+
// kOpenSSLSession is a serialized SSL_SESSION.
static const char kOpenSSLSession[] =
"MIIFqgIBAQICAwMEAsAvBCAG5Q1ndq4Yfmbeo1zwLkNRKmCXGdNgWvGT3cskV0yQ"
diff --git a/ssl/test/runner/tls13_tests.go b/ssl/test/runner/tls13_tests.go
index b9826eb..e4af165 100644
--- a/ssl/test/runner/tls13_tests.go
+++ b/ssl/test/runner/tls13_tests.go
@@ -301,6 +301,137 @@
})
testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "KeyShareWithServerHint-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{CurveX25519MLKEM768, CurveX25519},
+ Bugs: ProtocolBugs{
+ // The prediction is made based on the client's preference order.
+ ExpectedKeyShares: []CurveID{CurveX25519},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveX25519)),
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "KeyShareWithServerHint-UnrecognizedGroupIgnored-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{bogusCurve, CurveX25519},
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveX25519},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-server-supported-groups-hint", strconv.Itoa(int(bogusCurve)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveX25519)),
+ "-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+ "-expect-no-hrr",
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "KeyShareWithServerHint-OverridesExplicitKeyShare-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{CurveX25519MLKEM768, CurveMLKEM1024},
+ Bugs: ProtocolBugs{
+ // The group predicted from the server hint is used over explicitly
+ // configured key shares.
+ ExpectedKeyShares: []CurveID{CurveX25519MLKEM768},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-key-shares", strconv.Itoa(int(CurveMLKEM1024)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveMLKEM1024)),
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "KeyShareWithServerHint-OverridesExplicitEmptyKeyShare-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{CurveX25519MLKEM768, CurveMLKEM1024},
+ Bugs: ProtocolBugs{
+ // The group predicted from the server hint is used over explicitly
+ // configured empty key shares.
+ ExpectedKeyShares: []CurveID{CurveX25519MLKEM768},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-no-key-shares",
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveMLKEM1024)),
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "KeyShareWithServerHint-FallBackToExplicitKeyShare-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{CurveP256},
+ Bugs: ProtocolBugs{
+ // No group was predicted from the server hint, so the explicitly
+ // configured key shares are used.
+ ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveX25519},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-key-shares", strconv.Itoa(int(CurveX25519)),
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveP256)),
+ },
+ shouldFail: true,
+ expectedError: ":HANDSHAKE_FAILURE_ON_CLIENT_HELLO:",
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "KeyShareWithServerHint-FallBackToExplicitEmptyKeyShare-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{CurveP256},
+ Bugs: ProtocolBugs{
+ // No group was predicted from the server hint, so the explicitly
+ // configured key shares are used (in this case, empty to request HRR).
+ ExpectedKeyShares: []CurveID{},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-no-key-shares",
+ "-server-supported-groups-hint", strconv.Itoa(int(CurveP256)),
+ "-expect-hrr",
+ },
+ shouldFail: true,
+ expectedError: ":HANDSHAKE_FAILURE_ON_CLIENT_HELLO:",
+ })
+
+ testCases = append(testCases, testCase{
testType: serverTest,
name: "SkipEarlyData-TLS13",
config: Config{
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index 0cfdc37..bce633a 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -367,6 +367,8 @@
IntVectorFlag("-curves", &TestConfig::curves),
OptionalIntVectorFlag("-key-shares", &TestConfig::key_shares),
OptionalDefaultInitFlag("-no-key-shares", &TestConfig::key_shares),
+ IntVectorFlag("-server-supported-groups-hint",
+ &TestConfig::server_supported_groups_hint),
StringFlag("-trust-cert", &TestConfig::trust_cert),
StringFlag("-expect-server-name", &TestConfig::expect_server_name),
BoolFlag("-enable-ech-grease", &TestConfig::enable_ech_grease),
@@ -2532,6 +2534,12 @@
key_shares->size())) {
return nullptr;
}
+ if (!server_supported_groups_hint.empty() &&
+ !SSL_set1_server_supported_groups_hint(
+ ssl.get(), server_supported_groups_hint.data(),
+ server_supported_groups_hint.size())) {
+ return nullptr;
+ }
if (initial_timeout_duration_ms > 0) {
DTLSv1_set_initial_timeout_duration(ssl.get(), initial_timeout_duration_ms);
}
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 3b2f4e1..9745de0 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -66,6 +66,7 @@
std::vector<uint16_t> expect_peer_verify_prefs;
std::vector<uint16_t> curves;
std::optional<std::vector<uint16_t>> key_shares;
+ std::vector<uint16_t> server_supported_groups_hint;
std::string key_file;
std::string cert_file;
std::string trust_cert;