track known extensions rather than the smallest 64 (otherwise we cannot track the draft codepoints of ECH extensions)
diff --git a/lib/picotls.c b/lib/picotls.c
index 8188234..1efa8a2 100644
--- a/lib/picotls.c
+++ b/lib/picotls.c
@@ -419,7 +419,7 @@
};
struct st_ptls_extension_bitmap_t {
- uint8_t bits[8]; /* only ids below 64 is tracked */
+ uint64_t bits;
};
static const uint8_t zeroes_of_max_digest_size[PTLS_MAX_DIGEST_SIZE] = {0};
@@ -436,27 +436,17 @@
return 0;
}
-static inline int extension_bitmap_is_set(struct st_ptls_extension_bitmap_t *bitmap, uint16_t id)
-{
- if (id < sizeof(bitmap->bits) * 8)
- return (bitmap->bits[id / 8] & (1 << (id % 8))) != 0;
- return 0;
-}
-
-static inline void extension_bitmap_set(struct st_ptls_extension_bitmap_t *bitmap, uint16_t id)
-{
- if (id < sizeof(bitmap->bits) * 8)
- bitmap->bits[id / 8] |= 1 << (id % 8);
-}
-
-static inline void init_extension_bitmap(struct st_ptls_extension_bitmap_t *bitmap, unsigned hstype)
+static int extension_bitmap_testandset(struct st_ptls_extension_bitmap_t *bitmap, int hstype, uint16_t extid)
{
#define HSTYPE_TO_BIT(hstype) ((uint64_t)1 << ((hstype) + 1)) /* min(hstype) is -1 (PSEUDO_HRR) */
#define DEFINE_BIT(abbrev, hstype) static const uint64_t abbrev = HSTYPE_TO_BIT(PTLS_HANDSHAKE_TYPE_##hstype)
-#define EXT(extid, allowed_bits) \
+#define EXT(candext, allowed_bits) \
do { \
- if (((allowed_bits)&HSTYPE_TO_BIT(hstype)) == 0) \
- extension_bitmap_set(bitmap, PTLS_EXTENSION_TYPE_##extid); \
+ if (PTLS_UNLIKELY(extid == PTLS_EXTENSION_TYPE_##candext)) { \
+ allowed_hs_bits = allowed_bits; \
+ goto Found; \
+ } \
+ ext_bitmap_mask <<= 1; \
} while (0)
DEFINE_BIT(CH, CLIENT_HELLO);
@@ -467,7 +457,7 @@
DEFINE_BIT(CT, CERTIFICATE);
DEFINE_BIT(NST, NEW_SESSION_TICKET);
- *bitmap = (struct st_ptls_extension_bitmap_t){{0}};
+ uint64_t allowed_hs_bits, ext_bitmap_mask = 1;
/* clang-format off */
/* RFC 8446 section 4.2: "The table below indicates the messages where a given extension may appear... If an implementation
@@ -494,6 +484,16 @@
/* +-----------------------------------------+ */
/* clang-format on */
+ return 1;
+
+Found:
+ if ((allowed_hs_bits & HSTYPE_TO_BIT(hstype)) == 0)
+ return 0;
+ if ((bitmap->bits & ext_bitmap_mask) != 0)
+ return 0;
+ bitmap->bits |= ext_bitmap_mask;
+ return 1;
+
#undef HSTYPE_TO_BIT
#undef DEFINE_ABBREV
#undef EXT
@@ -871,17 +871,15 @@
#define decode_open_extensions(src, end, hstype, exttype, block) \
do { \
- struct st_ptls_extension_bitmap_t bitmap; \
- init_extension_bitmap(&bitmap, (hstype)); \
+ struct st_ptls_extension_bitmap_t bitmap = {0}; \
ptls_decode_open_block((src), end, 2, { \
while ((src) != end) { \
if ((ret = ptls_decode16((exttype), &(src), end)) != 0) \
goto Exit; \
- if (extension_bitmap_is_set(&bitmap, *(exttype)) != 0) { \
+ if (!extension_bitmap_testandset(&bitmap, (hstype), *(exttype))) { \
ret = PTLS_ALERT_ILLEGAL_PARAMETER; \
goto Exit; \
} \
- extension_bitmap_set(&bitmap, *(exttype)); \
ptls_decode_open_block((src), end, 2, block); \
} \
}); \
@@ -3727,7 +3725,6 @@
}); \
} while (0)
- uint16_t exttype;
int ret;
ptls_buffer_push_message_body(buf, NULL, PTLS_HANDSHAKE_TYPE_CLIENT_HELLO, {
@@ -3761,34 +3758,41 @@
/* extensions */
ptls_buffer_push_block(buf, 2, {
- decode_open_extensions(src, end, PTLS_HANDSHAKE_TYPE_CLIENT_HELLO, &exttype, {
- if (exttype == PTLS_EXTENSION_TYPE_ECH_OUTER_EXTENSIONS) {
- ptls_decode_open_block(src, end, 1, {
- do {
- uint16_t reftype;
- uint16_t outertype;
- uint16_t outersize;
- if ((ret = ptls_decode16(&reftype, &src, end)) != 0)
- goto Exit;
- while (1) {
- if ((ret = ptls_decode16(&outertype, &outer_ext, outer_ext_end)) != 0 ||
- (ret = ptls_decode16(&outersize, &outer_ext, outer_ext_end)) != 0)
- goto Exit;
- assert(outer_ext_end - outer_ext >= outersize);
- if (outertype == reftype)
- break;
- outer_ext += outersize;
- }
- buffer_push_extension(buf, reftype, {
- ptls_buffer_pushv(buf, outer_ext, outersize);
- outer_ext += outersize;
+ ptls_decode_open_block(src, end, 2, {
+ while (src != end) {
+ uint16_t exttype;
+ if ((ret = ptls_decode16(&exttype, &src, end)) != 0)
+ goto Exit;
+ ptls_decode_open_block(src, end, 2, {
+ if (exttype == PTLS_EXTENSION_TYPE_ECH_OUTER_EXTENSIONS) {
+ ptls_decode_open_block(src, end, 1, {
+ do {
+ uint16_t reftype;
+ uint16_t outertype;
+ uint16_t outersize;
+ if ((ret = ptls_decode16(&reftype, &src, end)) != 0)
+ goto Exit;
+ while (1) {
+ if ((ret = ptls_decode16(&outertype, &outer_ext, outer_ext_end)) != 0 ||
+ (ret = ptls_decode16(&outersize, &outer_ext, outer_ext_end)) != 0)
+ goto Exit;
+ assert(outer_ext_end - outer_ext >= outersize);
+ if (outertype == reftype)
+ break;
+ outer_ext += outersize;
+ }
+ buffer_push_extension(buf, reftype, {
+ ptls_buffer_pushv(buf, outer_ext, outersize);
+ outer_ext += outersize;
+ });
+ } while (src != end);
});
- } while (src != end);
- });
- } else {
- buffer_push_extension(buf, exttype, {
- ptls_buffer_pushv(buf, src, end - src);
- src = end;
+ } else {
+ buffer_push_extension(buf, exttype, {
+ ptls_buffer_pushv(buf, src, end - src);
+ src = end;
+ });
+ }
});
}
});
diff --git a/t/picotls.c b/t/picotls.c
index 4277350..f2ee970 100644
--- a/t/picotls.c
+++ b/t/picotls.c
@@ -44,6 +44,18 @@
ok(ptls_server_name_is_ipaddr("2001:db8::2:1"));
}
+static void test_extension_bitmap(void)
+{
+ struct st_ptls_extension_bitmap_t bitmap = {0};
+
+ /* disallowed extension is rejected */
+ ok(!extension_bitmap_testandset(&bitmap, PTLS_HANDSHAKE_TYPE_SERVER_HELLO, PTLS_EXTENSION_TYPE_COOKIE));
+
+ /* allowed extension is accepted first, rejected upon repetition */
+ ok(extension_bitmap_testandset(&bitmap, PTLS_HANDSHAKE_TYPE_SERVER_HELLO, PTLS_EXTENSION_TYPE_KEY_SHARE));
+ ok(!extension_bitmap_testandset(&bitmap, PTLS_HANDSHAKE_TYPE_SERVER_HELLO, PTLS_EXTENSION_TYPE_KEY_SHARE));
+}
+
static void test_select_cipher(void)
{
#define C(x) ((x) >> 8) & 0xff, (x)&0xff
@@ -1834,7 +1846,8 @@
void test_picotls(void)
{
subtest("is_ipaddr", test_is_ipaddr);
- subtest("select_cypher", test_select_cipher);
+ subtest("extension_bitmap", test_extension_bitmap);
+ subtest("select_cipher", test_select_cipher);
subtest("sha256", test_sha256);
subtest("sha384", test_sha384);
subtest("hmac-sha256", test_hmac_sha256);