Allow on_client_hello to swap ptls_context_t pointers, checking for cert types only after the callback has run
diff --git a/include/picotls.h b/include/picotls.h
index 8c179c5..f252865 100644
--- a/include/picotls.h
+++ b/include/picotls.h
@@ -535,6 +535,10 @@
const uint16_t *list;
size_t count;
} cipher_suites;
+ struct {
+ const uint8_t *list;
+ size_t count;
+ } server_certificate_types;
/**
* if ESNI was used
*/
diff --git a/lib/picotls.c b/lib/picotls.c
index ba4d74c..2e88b74 100644
--- a/lib/picotls.c
+++ b/lib/picotls.c
@@ -285,6 +285,7 @@
#define MAX_UNKNOWN_EXTENSIONS 16
#define MAX_CLIENT_CIPHERS 32
+#define MAX_CERTIFICATE_TYPES 2
struct st_ptls_client_hello_t {
uint16_t legacy_version;
@@ -336,6 +337,10 @@
unsigned early_data_indication : 1;
unsigned is_last_extension : 1;
} psk;
+ struct {
+ uint8_t list[MAX_CERTIFICATE_TYPES];
+ size_t count;
+ } server_certificate_types;
ptls_raw_extension_t unknown_extensions[MAX_UNKNOWN_EXTENSIONS + 1];
unsigned status_request : 1;
};
@@ -3318,7 +3323,6 @@
break;
case PTLS_EXTENSION_TYPE_SERVER_CERTIFICATE_TYPE:
ptls_decode_block(src, end, 1, {
- int found = 0;
size_t list_size = end - src;
/* RFC7250 4.1: No empty list, no list with single x509 element */
@@ -3328,17 +3332,9 @@
}
for (size_t i = 0; i < end - src; i++) {
- if ((*src == PTLS_CERTIFICATE_TYPE_X509 && !tls->ctx->use_raw_public_keys) ||
- (*src == PTLS_CERTIFICATE_TYPE_RAW_PUBLIC_KEY && tls->ctx->use_raw_public_keys)) {
- found = 1;
- break;
- }
+ if (ch->server_certificate_types.count < PTLS_ELEMENTSOF(ch->server_certificate_types.list))
+ ch->server_certificate_types.list[ch->server_certificate_types.count++] = *src;
}
- if (!found) {
- ret = PTLS_ALERT_UNSUPPORTED_CERTIFICATE;
- goto Exit;
- }
- src = end;
});
break;
case PTLS_EXTENSION_TYPE_COMPRESS_CERTIFICATE:
@@ -3692,7 +3688,7 @@
}
*ch = (struct st_ptls_client_hello_t){0, NULL, {NULL}, {NULL}, 0, {NULL}, {NULL}, {NULL}, {{0}},
- {NULL}, {NULL}, {{{NULL}}}, {{0}}, {{0}}, {{NULL}}, {NULL}, {{UINT16_MAX}}};
+ {NULL}, {NULL}, {{{NULL}}}, {{0}}, {{0}}, {{NULL}}, {NULL}, {{0}}, {{UINT16_MAX}}};
/* decode ClientHello */
if ((ret = decode_client_hello(tls, ch, message.base + PTLS_HANDSHAKE_HEADER_SIZE, message.base + message.len, properties)) !=
@@ -3778,15 +3774,41 @@
{ch->signature_algorithms.list, ch->signature_algorithms.count},
{ch->cert_compression_algos.list, ch->cert_compression_algos.count},
{ch->client_ciphers.list, ch->client_ciphers.count},
+ {ch->server_certificate_types.list, ch->server_certificate_types.count},
is_esni};
ret = tls->ctx->on_client_hello->cb(tls->ctx->on_client_hello, tls, ¶ms);
} else {
ret = 0;
}
+
if (is_esni)
free(server_name.base);
if (ret != 0)
goto Exit;
+ int cert_type_found = 0;
+ if (tls->ctx->use_raw_public_keys) {
+ for (size_t i = 0; i < ch->server_certificate_types.count; i++) {
+ if (ch->server_certificate_types.list[i] == PTLS_CERTIFICATE_TYPE_RAW_PUBLIC_KEY) {
+ cert_type_found = 1;
+ break;
+ }
+ }
+ } else {
+ if (ch->server_certificate_types.count != 0) {
+ for (size_t i = 0; i < ch->server_certificate_types.count; i++) {
+ if (ch->server_certificate_types.list[i] == PTLS_CERTIFICATE_TYPE_X509) {
+ cert_type_found = 1;
+ break;
+ }
+ }
+ } else {
+ cert_type_found = 1;
+ }
+ }
+ if (!cert_type_found) {
+ ret = PTLS_ALERT_UNSUPPORTED_CERTIFICATE;
+ goto Exit;
+ }
} else {
if (ch->psk.early_data_indication) {
ret = PTLS_ALERT_DECODE_ERROR;