Add API for configuring client key shares
This change introduces a new API to allow callers to configure the exact
set of key shares to be sent in a client's key_share extension. If the
supported groups for a connection are modified, any previously-selected
key shares are cleared if they are no longer compatible. Clients are
allowed to configure an empty list of key shares, which results in
always taking a round-trip for HelloRetryRequest.
Bug: 437414371
Change-Id: Ibd2a9b217fa3c746dec194a027d3ec45a81bb578
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81567
Commit-Queue: Lily Chen <chlily@google.com>
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/mem_internal.h b/crypto/mem_internal.h
index 20b45fd..5ba44dc 100644
--- a/crypto/mem_internal.h
+++ b/crypto/mem_internal.h
@@ -379,6 +379,8 @@
template <typename T, size_t N>
class InplaceVector {
public:
+ using value_type = std::remove_cv_t<T>;
+
InplaceVector() = default;
InplaceVector(const InplaceVector &other) { *this = other; }
InplaceVector(InplaceVector &&other) { *this = std::move(other); }
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 79eb0c7..81de2fd 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -2605,6 +2605,47 @@
OPENSSL_EXPORT int SSL_get_negotiated_group(const SSL *ssl);
+// Client key shares.
+//
+// The key_share extension in TLS 1.3 (RFC 8446 section 4.2.8) may be sent in
+// the initial ClientHello to provide key exchange parameters for a subset of
+// the groups offered in the client's supported_groups extension, in hopes of
+// saving a round-trip by having proactively started a key exchange for the
+// ultimately-negotiated group.
+//
+// If not otherwise configured, the default client key share selection logic
+// outputs key shares for up to two supported groups, at most one of which is
+// post-quantum.
+
+// SSL_set1_client_key_shares, when called by a client before the handshake,
+// configures |ssl| to send a key_share extension in the initial ClientHello
+// containing exactly the groups given by |group_ids|, in the order given. Each
+// member of |group_ids| should be one of the |SSL_GROUP_*| constants, and they
+// must be unique. This function returns one on success and zero on failure.
+//
+// If non-empty, the sequence of |group_ids| must be a (not necessarily
+// contiguous) subsequence of the groups supported by |ssl|, which may have been
+// configured explicitly on |ssl| or its context, or populated by default.
+// Caller should finish configuring the group list before calling this function.
+// Changing the supported groups for |ssl| after having set client key shares
+// will result in the key share selections being reset if this constraint no
+// longer holds.
+//
+// Setting an empty sequence of |group_ids| results in an empty client
+// key_share, which will cause the handshake to always take an extra round-trip
+// for HelloRetryRequest.
+//
+// An extra round-trip will be needed if the server's choice of group is not
+// among the key shares sent; conversely, sending any key shares other than the
+// server's choice wastes CPU and bandwidth (the latter is particularly costly
+// for post-quantum key exchanges). To avoid these sub-optimal outcomes,
+// key shares should be chosen such that they are likely to be supported by the
+// peer server.
+OPENSSL_EXPORT int SSL_set1_client_key_shares(SSL *ssl,
+ const uint16_t *group_ids,
+ size_t num_group_ids);
+
+
// Certificate verification.
//
// SSL may authenticate either endpoint with an X.509 certificate. Typically
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index a09594e..2169966 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -2197,6 +2197,45 @@
return true;
}
+ const Span<const uint16_t> supported_group_list =
+ hs->config->supported_group_list;
+ if (supported_group_list.empty()) {
+ OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED);
+ return false;
+ }
+
+ InplaceVector<uint16_t, 2> default_key_shares;
+
+ Span<const uint16_t> selected_key_shares;
+
+ // Determine the key shares to send.
+ if (override_group_id != 0) {
+ assert(std::find(supported_group_list.begin(), supported_group_list.end(),
+ override_group_id) != supported_group_list.end());
+ selected_key_shares = Span(&override_group_id, 1u);
+ } else if (ssl->config->client_key_share_selections.has_value()) {
+ selected_key_shares = *(ssl->config->client_key_share_selections);
+ } else {
+ // By default, predict the most preferred group.
+ if (!default_key_shares.TryPushBack(supported_group_list[0])) {
+ return false;
+ }
+ // We'll try to include one post-quantum and one classical initial key
+ // share.
+ for (size_t i = 1; i < supported_group_list.size(); i++) {
+ if (is_post_quantum_group(default_key_shares[0]) ==
+ is_post_quantum_group(supported_group_list[i])) {
+ continue;
+ }
+ if (!default_key_shares.TryPushBack(supported_group_list[i])) {
+ return false;
+ }
+ assert(default_key_shares[1] != default_key_shares[0]);
+ break;
+ }
+ selected_key_shares = default_key_shares;
+ }
+
bssl::ScopedCBB cbb;
if (!CBB_init(cbb.get(), 64)) {
return false;
@@ -2211,45 +2250,13 @@
}
}
- uint16_t group_id = override_group_id;
- uint16_t second_group_id = 0;
- if (override_group_id == 0) {
- // Predict the most preferred group.
- Span<const uint16_t> groups = hs->config->supported_group_list;
- if (groups.empty()) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_NO_GROUPS_SPECIFIED);
- return false;
- }
-
- group_id = groups[0];
-
- // We'll try to include one post-quantum and one classical initial key
- // share.
- for (size_t i = 1; i < groups.size() && second_group_id == 0; i++) {
- if (is_post_quantum_group(group_id) != is_post_quantum_group(groups[i])) {
- second_group_id = groups[i];
- assert(second_group_id != group_id);
- }
- }
- }
-
CBB key_exchange;
- {
+ for (const uint16_t group_id : selected_key_shares) {
UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(group_id);
- if (key_share == nullptr || !CBB_add_u16(cbb.get(), group_id) ||
- !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
- !key_share->Generate(&key_exchange) ||
- !hs->key_shares.TryPushBack(std::move(key_share))) {
- return false;
- }
- }
-
- if (second_group_id != 0) {
- // TODO(chlily): Fix temporary code duplication.
- UniquePtr<SSLKeyShare> key_share = SSLKeyShare::Create(second_group_id);
- if (key_share == nullptr || !CBB_add_u16(cbb.get(), second_group_id) ||
- !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) ||
- !key_share->Generate(&key_exchange) ||
+ if (key_share == nullptr || //
+ !CBB_add_u16(cbb.get(), group_id) || //
+ !CBB_add_u16_length_prefixed(cbb.get(), &key_exchange) || //
+ !key_share->Generate(&key_exchange) || //
!hs->key_shares.TryPushBack(std::move(key_share))) {
return false;
}
@@ -2267,7 +2274,12 @@
return true;
}
- assert(!hs->key_share_bytes.empty());
+ // The caller may explicitly configure empty key shares to request a
+ // HelloRetryRequest.
+ assert(!hs->key_share_bytes.empty() ||
+ (hs->config->client_key_share_selections.has_value() &&
+ hs->config->client_key_share_selections->empty()));
+
CBB contents, kse_bytes;
if (!CBB_add_u16(out_compressible, TLSEXT_TYPE_key_share) ||
!CBB_add_u16_length_prefixed(out_compressible, &contents) ||
diff --git a/ssl/internal.h b/ssl/internal.h
index 37eb54e..d72f7a5 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -1770,9 +1770,7 @@
UniquePtr<ERR_SAVE_STATE> error;
// key_shares are the current key exchange instances, in preference order. Any
- // members of this vector must be non-null. By default, no more than two are
- // used, and the second is only used as a client if we believe that we should
- // offer two key shares in a ClientHello.
+ // members of this vector must be non-null.
InplaceVector<UniquePtr<SSLKeyShare>, kNumNamedGroups> key_shares;
// transcript is the current handshake transcript.
@@ -2130,9 +2128,19 @@
bool ssl_setup_extension_permutation(SSL_HANDSHAKE *hs);
// ssl_setup_key_shares computes client key shares and saves them in |hs|. It
-// returns true on success and false on failure. If |override_group_id| is zero,
-// it offers the default groups, including GREASE. If it is non-zero, it offers
-// a single key share of the specified group.
+// returns true on success and false on failure. In order of precedence:
+//
+// - If |override_group_id| is non-zero, it offers a single key share of the
+// specified group.
+//
+// - If key shares were previously specified by the caller via
+// |SSL_set_client_key_shares|, those are used.
+//
+// - If |override_group_id| is zero and no selections were made by the caller
+// via |SSL_set_client_key_shares|, it selects the first supported group and
+// may select a second if at most one of the two is a post-quantum group.
+//
+// GREASE will be included if enabled, when |override_group_id| is zero.
bool ssl_setup_key_shares(SSL_HANDSHAKE *hs, uint16_t override_group_id);
// ssl_setup_pake_shares computes the client PAKE shares and saves them in |hs|.
@@ -3275,7 +3283,19 @@
// Trust anchor IDs to be requested in the trust_anchors extension.
std::optional<Array<uint8_t>> requested_trust_anchors;
- Array<uint16_t> supported_group_list; // our list
+ // Our list of supported groups. If this list is modified, for a client,
+ // |client_key_share_selections| must be reset if the key shares are no longer
+ // a valid subsequence of the supported group list.
+ Array<uint16_t> supported_group_list;
+
+ // For a client, this may contain a subsequence of the group IDs in
+ // |suppported_group_list|, which gives the groups for which key shares should
+ // be sent in the client's key_share extension. This is non-nullopt iff
+ // |SSL_set_client_key_shares| was successfully called to configure key
+ // shares. If non-nullopt, these groups are in the same order as they appear
+ // in |supported_group_list|, and may not contain duplicates.
+ std::optional<InplaceVector<uint16_t, kNumNamedGroups>>
+ client_key_share_selections;
// channel_id_private is the client's Channel ID private key, or null if
// Channel ID should not be offered on this connection.
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 2e0db35..532744f 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1844,6 +1844,37 @@
return check_no_duplicates(group_ids);
}
+// validate_key_shares returns whether the `requested_key_shares` are free of
+// duplicates and are a (correctly ordered) subsequence of the supported
+// `groups`.
+static bool validate_key_shares(Span<const uint16_t> requested_key_shares,
+ Span<const uint16_t> groups) {
+ if (!check_no_duplicates(requested_key_shares)) {
+ return false;
+ }
+ if (requested_key_shares.size() > groups.size()) {
+ return false;
+ }
+ size_t key_shares_idx = 0u, groups_idx = 0u;
+ while (key_shares_idx < requested_key_shares.size() &&
+ groups_idx < groups.size()) {
+ if (requested_key_shares[key_shares_idx] == groups[groups_idx++]) {
+ ++key_shares_idx;
+ }
+ }
+ return key_shares_idx == requested_key_shares.size();
+}
+
+static void clear_key_shares_if_invalid(SSL_CONFIG *config) {
+ if (!config->client_key_share_selections) {
+ return;
+ }
+ if (!validate_key_shares(*(config->client_key_share_selections),
+ config->supported_group_list)) {
+ config->client_key_share_selections.reset();
+ }
+}
+
int SSL_CTX_set1_group_ids(SSL_CTX *ctx, const uint16_t *group_ids,
size_t num_group_ids) {
auto span = Span(group_ids, num_group_ids);
@@ -1862,8 +1893,12 @@
if (span.empty()) {
span = DefaultSupportedGroupIds();
}
- return check_group_ids(span) &&
- ssl->config->supported_group_list.CopyFrom(span);
+ if (check_group_ids(span) &&
+ ssl->config->supported_group_list.CopyFrom(span)) {
+ clear_key_shares_if_invalid(ssl->config.get());
+ return 1;
+ }
+ return 0;
}
static bool ssl_nids_to_group_ids(Array<uint16_t> *out_group_ids,
@@ -1899,8 +1934,12 @@
if (!ssl->config) {
return 0;
}
- return ssl_nids_to_group_ids(&ssl->config->supported_group_list,
- Span(groups, num_groups));
+ if (ssl_nids_to_group_ids(&ssl->config->supported_group_list,
+ Span(groups, num_groups))) {
+ clear_key_shares_if_invalid(ssl->config.get());
+ return 1;
+ }
+ return 0;
}
static bool ssl_str_to_group_ids(Array<uint16_t> *out_group_ids,
@@ -1951,7 +1990,11 @@
if (!ssl->config) {
return 0;
}
- return ssl_str_to_group_ids(&ssl->config->supported_group_list, groups);
+ if (ssl_str_to_group_ids(&ssl->config->supported_group_list, groups)) {
+ clear_key_shares_if_invalid(ssl->config.get());
+ return 1;
+ }
+ return 0;
}
uint16_t SSL_get_group_id(const SSL *ssl) {
@@ -1971,6 +2014,23 @@
return ssl_group_id_to_nid(group_id);
}
+int SSL_set1_client_key_shares(SSL *ssl, const uint16_t *group_ids,
+ size_t num_group_ids) {
+ if (!ssl->config) {
+ return 0;
+ }
+ auto requested_key_shares = Span(group_ids, num_group_ids);
+ if (!validate_key_shares(requested_key_shares,
+ ssl->config->supported_group_list)) {
+ return 0;
+ }
+
+ assert(requested_key_shares.size() <= kNumNamedGroups);
+ ssl->config->client_key_share_selections.emplace();
+ ssl->config->client_key_share_selections->CopyFrom(requested_key_shares);
+ return 1;
+}
+
int SSL_CTX_set_tmp_dh(SSL_CTX *ctx, const DH *dh) { return 1; }
int SSL_set_tmp_dh(SSL *ssl, const DH *dh) { return 1; }
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 96edb61..1b34c14 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -722,6 +722,146 @@
}
}
+TEST(SSLTest, SetClientKeyShares) {
+ const struct {
+ const char *description;
+ std::vector<uint16_t> supported_groups;
+ std::vector<uint16_t> key_shares;
+ bool expected_success;
+ } kTests[] = {
+ {
+ "Empty key shares with default supported groups",
+ {},
+ {},
+ true,
+ },
+ {
+ "Empty key shares with custom supported groups",
+ {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+ {},
+ true,
+ },
+ {
+ "One key share matching default supported groups",
+ {},
+ {SSL_GROUP_X25519},
+ true,
+ },
+ {
+ "One key share matching custom supported groups",
+ {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+ {SSL_GROUP_X25519},
+ true,
+ },
+ {
+ "Key share not in supported default groups",
+ {},
+ {SSL_GROUP_MLKEM1024},
+ false,
+ },
+ {
+ "Key share not in supported custom groups",
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1},
+ {SSL_GROUP_X25519_MLKEM768},
+ false,
+ },
+ {
+ "Multiple key shares, in correct order",
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+ {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+ true,
+ },
+ {
+ "Multiple key shares, out of order",
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+ {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519},
+ false,
+ },
+ {
+ "More than two key shares",
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768,
+ SSL_GROUP_MLKEM1024},
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+ true,
+ },
+ {
+ "Key shares cover all supported groups",
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+ true,
+ },
+ {
+ "Multiple key shares, not all valid",
+ {SSL_GROUP_X25519, SSL_GROUP_X25519_MLKEM768},
+ {SSL_GROUP_X25519, SSL_GROUP_SECP256R1, SSL_GROUP_X25519_MLKEM768},
+ false,
+ },
+ {
+ "Key shares contain duplicates",
+ {},
+ {SSL_GROUP_X25519, SSL_GROUP_X25519},
+ false,
+ },
+ };
+
+ for (const auto &t : kTests) {
+ SCOPED_TRACE(t.description);
+ bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+ ASSERT_TRUE(ctx);
+ bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+ ASSERT_TRUE(ssl);
+ ASSERT_FALSE(ssl->config->client_key_share_selections.has_value());
+
+ ASSERT_TRUE(SSL_set1_group_ids(ssl.get(), t.supported_groups.data(),
+ t.supported_groups.size()));
+ EXPECT_EQ(SSL_set1_client_key_shares(ssl.get(), t.key_shares.data(),
+ t.key_shares.size()),
+ t.expected_success);
+ if (t.expected_success) {
+ ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+ EXPECT_THAT(ssl->config->client_key_share_selections.value(),
+ ElementsAreArray(t.key_shares));
+ }
+ }
+}
+
+// Test the behavior that modifying the SSL's supported groups results in
+// clearing the previously set client key shares, iff the supported groups
+// become incompatible with the key shares.
+TEST(SSLTest, ClientKeySharesResetAfterChangingGroups) {
+ bssl::UniquePtr<SSL_CTX> ctx(SSL_CTX_new(TLS_method()));
+ ASSERT_TRUE(ctx);
+ bssl::UniquePtr<SSL> ssl(SSL_new(ctx.get()));
+ ASSERT_TRUE(ssl);
+ ASSERT_FALSE(ssl->config->client_key_share_selections.has_value());
+
+ // An initial groups list and key shares that are compatible.
+ const uint16_t kGroups1[] = {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519};
+ const uint16_t kKeyShares[] = {SSL_GROUP_X25519_MLKEM768, SSL_GROUP_X25519};
+ ASSERT_TRUE(
+ SSL_set1_group_ids(ssl.get(), kGroups1, std::size(kGroups1)));
+ ASSERT_TRUE(SSL_set1_client_key_shares(ssl.get(), kKeyShares,
+ std::size(kKeyShares)));
+ ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+ EXPECT_EQ(ssl->config->client_key_share_selections->size(), 2u);
+
+ // A new groups list that is still compatible with the previously set key
+ // shares.
+ const uint16_t kGroups2[] = {SSL_GROUP_MLKEM1024, SSL_GROUP_X25519_MLKEM768,
+ SSL_GROUP_X25519};
+ ASSERT_TRUE(
+ SSL_set1_group_ids(ssl.get(), kGroups2, std::size(kGroups2)));
+ ASSERT_TRUE(ssl->config->client_key_share_selections.has_value());
+ EXPECT_EQ(ssl->config->client_key_share_selections->size(), 2u);
+
+ // A new groups list that is no longer compatible with the previously set key
+ // shares.
+ const uint16_t kGroups3[] = {SSL_GROUP_MLKEM1024, SSL_GROUP_X25519};
+ ASSERT_TRUE(
+ SSL_set1_group_ids(ssl.get(), kGroups3, std::size(kGroups3)));
+ EXPECT_FALSE(ssl->config->client_key_share_selections.has_value());
+}
+
// kOpenSSLSession is a serialized SSL_SESSION.
static const char kOpenSSLSession[] =
"MIIFqgIBAQICAwMEAsAvBCAG5Q1ndq4Yfmbeo1zwLkNRKmCXGdNgWvGT3cskV0yQ"
diff --git a/ssl/test/runner/pake_tests.go b/ssl/test/runner/pake_tests.go
index ddd2eb3..677a3a4 100644
--- a/ssl/test/runner/pake_tests.go
+++ b/ssl/test/runner/pake_tests.go
@@ -14,7 +14,10 @@
package runner
-import "errors"
+import (
+ "errors"
+ "strconv"
+)
func addPAKETests() {
spakeCredential := Credential{
@@ -231,17 +234,22 @@
shimCredentials: []*Credential{&spakeCredential},
})
testCases = append(testCases, testCase{
- // A PAKE client will not offer key shares, so the client should
- // reject a HelloRetryRequest requesting a different key share.
+ // A PAKE client will not offer key shares, even if explicitly configured,
+ // and the client should reject a HelloRetryRequest requesting a different
+ // key share.
name: "PAKE-Client-HRRKeyShare",
testType: clientTest,
config: Config{
MinVersion: VersionTLS13,
Credential: &spakeCredential,
Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{},
SendHelloRetryRequestCurve: CurveX25519,
},
},
+ flags: []string{
+ "-key-shares", strconv.Itoa(int(CurveP256)),
+ },
shimCredentials: []*Credential{&spakeCredential},
shouldFail: true,
expectedError: ":UNEXPECTED_EXTENSION:",
diff --git a/ssl/test/runner/tls13_tests.go b/ssl/test/runner/tls13_tests.go
index 2b08ec8..b9826eb 100644
--- a/ssl/test/runner/tls13_tests.go
+++ b/ssl/test/runner/tls13_tests.go
@@ -132,6 +132,175 @@
})
testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveP256},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveP256)),
+ "-key-shares", strconv.Itoa(int(CurveP256)),
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-DefaultSupportedGroups-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveX25519},
+ },
+ },
+ flags: []string{
+ // Configure key shares without explicitly configuring supported groups.
+ "-key-shares", strconv.Itoa(int(CurveX25519)),
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-Multiple-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveX25519, CurveX25519MLKEM768},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveP256)),
+ // Predict the top 2 out of 3 supported curves.
+ "-key-shares", strconv.Itoa(int(CurveX25519)),
+ "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+ "-expect-no-hrr",
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-MultipleNonContiguous-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveX25519, CurveP256},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveP256)),
+ // Predict the first and last of 3 supported curves.
+ "-key-shares", strconv.Itoa(int(CurveX25519)),
+ "-key-shares", strconv.Itoa(int(CurveP256)),
+ "-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+ "-expect-no-hrr",
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-NotMostPreferred-HRR-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveP256},
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveP256)),
+ // Predict some curves that we support, not including the top one.
+ "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-key-shares", strconv.Itoa(int(CurveP256)),
+ "-expect-curve-id", strconv.Itoa(int(CurveX25519)),
+ // Check that we triggered a HelloRetryRequest.
+ "-expect-hrr",
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-PeerSelectedLaterKeyShare-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ CurvePreferences: []CurveID{CurveP256},
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{CurveX25519MLKEM768, CurveP256},
+ // Make the server select one of the groups with a key_share, but not
+ // the most preferred one.
+ SendCurve: CurveP256,
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveP256)),
+ "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-key-shares", strconv.Itoa(int(CurveP256)),
+ "-expect-curve-id", strconv.Itoa(int(CurveP256)),
+ "-expect-no-hrr",
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-All-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{
+ CurveP256,
+ CurveP384,
+ CurveP521,
+ CurveX25519,
+ CurveX25519Kyber768,
+ CurveX25519MLKEM768,
+ CurveMLKEM1024,
+ },
+ },
+ },
+ flags: []string{
+ "-curves", strconv.Itoa(int(CurveP256)),
+ "-curves", strconv.Itoa(int(CurveP384)),
+ "-curves", strconv.Itoa(int(CurveP521)),
+ "-curves", strconv.Itoa(int(CurveX25519)),
+ "-curves", strconv.Itoa(int(CurveX25519Kyber768)),
+ "-curves", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-curves", strconv.Itoa(int(CurveMLKEM1024)),
+ "-key-shares", strconv.Itoa(int(CurveP256)),
+ "-key-shares", strconv.Itoa(int(CurveP384)),
+ "-key-shares", strconv.Itoa(int(CurveP521)),
+ "-key-shares", strconv.Itoa(int(CurveX25519)),
+ "-key-shares", strconv.Itoa(int(CurveX25519Kyber768)),
+ "-key-shares", strconv.Itoa(int(CurveX25519MLKEM768)),
+ "-key-shares", strconv.Itoa(int(CurveMLKEM1024)),
+ },
+ })
+
+ testCases = append(testCases, testCase{
+ testType: clientTest,
+ name: "CustomKeyShares-Empty-HRR-TLS13",
+ config: Config{
+ MinVersion: VersionTLS13,
+ Bugs: ProtocolBugs{
+ ExpectedKeyShares: []CurveID{},
+ },
+ },
+ flags: []string{
+ "-no-key-shares",
+ "-expect-hrr",
+ },
+ })
+
+ testCases = append(testCases, testCase{
testType: serverTest,
name: "SkipEarlyData-TLS13",
config: Config{
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index f973d9e..0cfdc37 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -53,12 +53,17 @@
template <typename Config>
struct Flag {
const char *name;
+ // has_param, if true, causes the parser to look for a param value following
+ // this flag's name.
bool has_param;
// skip_handshaker, if true, causes this flag to be skipped when
// forwarding flags to the handshaker. This should be used with flags
// that only impact connecting to the runner.
bool skip_handshaker;
- // If |has_param| is false, |param| will be nullptr.
+ // set_param is called after parsing to interpret and set the result on
+ // `config`. If `has_param` is false for this flag, `param` will be nullptr.
+ // This function should return whether the param value (or lack thereof) was
+ // valid for this flag.
std::function<bool(Config *config, const char *param)> set_param;
};
@@ -168,6 +173,39 @@
}};
}
+// Defines a flag which adds an integer param value to an optional vector of
+// integers.
+template <typename Config, typename T>
+Flag<Config> OptionalIntVectorFlag(const char *name,
+ std::optional<std::vector<T>> Config::*field,
+ bool skip_handshaker = false) {
+ return Flag<Config>{name, true, skip_handshaker,
+ [=](Config *config, const char *param) -> bool {
+ if (!(config->*field)) {
+ (config->*field).emplace();
+ }
+ T value;
+ if (!StringToInt(&value, param)) {
+ return false;
+ }
+ (config->*field)->push_back(value);
+ return true;
+ }};
+}
+
+// Defines a flag which resets a std::optional field to its default constructed
+// value.
+template <typename Config, typename T>
+Flag<Config> OptionalDefaultInitFlag(const char *name,
+ std::optional<T> Config::*field,
+ bool skip_handshaker = false) {
+ return Flag<Config>{name, false, skip_handshaker,
+ [=](Config *config, const char *) -> bool {
+ (config->*field).emplace();
+ return true;
+ }};
+}
+
template <typename Config>
Flag<Config> StringFlag(const char *name, std::string Config::*field,
bool skip_handshaker = false) {
@@ -327,6 +365,8 @@
IntVectorFlag("-expect-peer-verify-pref",
&TestConfig::expect_peer_verify_prefs),
IntVectorFlag("-curves", &TestConfig::curves),
+ OptionalIntVectorFlag("-key-shares", &TestConfig::key_shares),
+ OptionalDefaultInitFlag("-no-key-shares", &TestConfig::key_shares),
StringFlag("-trust-cert", &TestConfig::trust_cert),
StringFlag("-expect-server-name", &TestConfig::expect_server_name),
BoolFlag("-enable-ech-grease", &TestConfig::enable_ech_grease),
@@ -679,6 +719,9 @@
if (!skip) {
if (out != nullptr) {
if (!flag->set_param(out, param)) {
+ if (!param) {
+ param = "(no parameter)";
+ }
fprintf(stderr, "Invalid parameter for %s: %s\n", name, param);
return false;
}
@@ -687,6 +730,9 @@
if (!flag->set_param(out_initial, param) ||
!flag->set_param(out_resume, param) ||
!flag->set_param(out_retry, param)) {
+ if (!param) {
+ param = "(no parameter)";
+ }
fprintf(stderr, "Invalid parameter for %s: %s\n", name, param);
return false;
}
@@ -2481,6 +2527,11 @@
!SSL_set1_group_ids(ssl.get(), curves.data(), curves.size())) {
return nullptr;
}
+ if (key_shares.has_value() &&
+ !SSL_set1_client_key_shares(ssl.get(), key_shares->data(),
+ key_shares->size())) {
+ return nullptr;
+ }
if (initial_timeout_duration_ms > 0) {
DTLSv1_set_initial_timeout_duration(ssl.get(), initial_timeout_duration_ms);
}
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index de2f5c5..3b2f4e1 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -65,6 +65,7 @@
std::vector<uint16_t> verify_prefs;
std::vector<uint16_t> expect_peer_verify_prefs;
std::vector<uint16_t> curves;
+ std::optional<std::vector<uint16_t>> key_shares;
std::string key_file;
std::string cert_file;
std::string trust_cert;
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index dc417b9..a664321 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -253,7 +253,11 @@
// The ECH extension, if present, was already parsed by
// |check_ech_confirmation|.
SSLExtension cookie(TLSEXT_TYPE_cookie),
- key_share(TLSEXT_TYPE_key_share, !hs->key_share_bytes.empty()),
+ // If offering PAKE, we won't send key_share extensions and we should
+ // reject key_share from the peer. Otherwise, it is valid to have sent an
+ // empty key_share extension, and expect the HelloRetryRequest to contain
+ // a key_share.
+ key_share(TLSEXT_TYPE_key_share, !hs->pake_prover),
supported_versions(TLSEXT_TYPE_supported_versions),
ech_unused(TLSEXT_TYPE_encrypted_client_hello,
hs->selected_ech_config || hs->config->ech_grease_enabled);
@@ -286,8 +290,6 @@
}
if (key_share.present) {
- // If offering PAKE, we won't send key_share extensions, in which case we
- // would have rejected key_share from the peer.
assert(!hs->pake_prover);
uint16_t group_id;