Check hs->early_session, not ssl->session, for the early data limit.

ServerHello/EncryptedExtensions/Finished is logically one atomic flight
that exits the early data state, we have process each message
sequentially. Until we've processed Finished, we are still in the early
data state and must support writing data. Individual messages *are*
processed atomically, so the interesting points are before ServerHello
(already tested), after ServerHello, and after EncryptedExtensions.

The TLS 1.3 handshake internally clears ssl->session when processing
ServerHello, so getting the early data information from ssl->session
does not work. Instead, use hs->early_session, which is what other
codepaths use.

I've tested this with runner rather than ssl_test, so we can test both
post-SH and post-EE states. ssl_test would be more self-contained, since
we can directly control the API calls, but it cannot test the post-EE
state. To reduce record overhead, our production implementation packs EE
and Finished into the same record, which means the handshake will
process the two atomically. Instead, I've tested this in runner, with a
flag to partially drive the handshake before reading early data.

I've also tweaked the logic to hopefully be a little clearer.

Bug: chromium:1208784
Change-Id: Ia4901042419c5324054f97743bd1aac59ebf8f24
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47485
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index 457696d..450f7dc 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -112,6 +112,8 @@
 #include <limits.h>
 #include <string.h>
 
+#include <algorithm>
+
 #include <openssl/err.h>
 #include <openssl/evp.h>
 #include <openssl/mem.h>
@@ -138,10 +140,9 @@
     return -1;
   }
 
-  unsigned tot, n, nw;
-
+  // TODO(davidben): Switch this logic to |size_t| and |bssl::Span|.
   assert(ssl->s3->wnum <= INT_MAX);
-  tot = ssl->s3->wnum;
+  unsigned tot = ssl->s3->wnum;
   ssl->s3->wnum = 0;
 
   // Ensure that if we end up with a smaller value of data to write out than
@@ -159,29 +160,23 @@
   const int is_early_data_write =
       !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write;
 
-  n = len - tot;
+  unsigned n = len - tot;
   for (;;) {
-    // max contains the maximum number of bytes that we can put into a record.
-    unsigned max = ssl->max_send_fragment;
-    if (is_early_data_write &&
-        max > ssl->session->ticket_max_early_data -
-                  ssl->s3->hs->early_data_written) {
-      max =
-          ssl->session->ticket_max_early_data - ssl->s3->hs->early_data_written;
-      if (max == 0) {
+    size_t max_send_fragment = ssl->max_send_fragment;
+    if (is_early_data_write) {
+      SSL_HANDSHAKE *hs = ssl->s3->hs.get();
+      if (hs->early_data_written >= hs->early_session->ticket_max_early_data) {
         ssl->s3->wnum = tot;
-        ssl->s3->hs->can_early_write = false;
+        hs->can_early_write = false;
         *out_needs_handshake = true;
         return -1;
       }
+      max_send_fragment = std::min(
+          max_send_fragment, size_t{hs->early_session->ticket_max_early_data -
+                                    hs->early_data_written});
     }
 
-    if (n > max) {
-      nw = max;
-    } else {
-      nw = n;
-    }
-
+    const size_t nw = std::min(max_send_fragment, size_t{n});
     int ret = do_tls_write(ssl, SSL3_RT_APPLICATION_DATA, &in[tot], nw);
     if (ret <= 0) {
       ssl->s3->wnum = tot;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index 8931349..9438c1f 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -851,6 +851,7 @@
   int ret;
   SSL *ssl = ssl_uniqueptr->get();
   SSL_CTX *session_ctx = SSL_get_SSL_CTX(ssl);
+  TestState *test_state = GetTestState(ssl);
 
   if (!config->implicit_handshake) {
     if (config->handoff) {
@@ -859,6 +860,7 @@
         return false;
       }
       ssl = ssl_uniqueptr->get();
+      test_state = GetTestState(ssl);
 #else
       fprintf(stderr, "The external handshaker can only be used on Linux\n");
       return false;
@@ -903,9 +905,44 @@
       return false;
     }
 
+    if (config->early_write_after_message != 0) {
+      if (!SSL_in_early_data(ssl) || config->is_server) {
+        fprintf(stderr,
+                "-early-write-after-message only works for 0-RTT connections "
+                "on servers.\n");
+        return false;
+      }
+      if (!config->shim_writes_first || !config->async) {
+        fprintf(stderr,
+                "-early-write-after-message requires -shim-writes-first and "
+                "-async.\n");
+        return false;
+      }
+      // Run the handshake until the specified message. Note that, if a
+      // handshake record contains multiple messages, |SSL_do_handshake| usually
+      // processes both atomically. The test must ensure there is a record
+      // boundary after the desired message. Checking |last_message_received|
+      // confirms this.
+      do {
+        ret = SSL_do_handshake(ssl);
+      } while (test_state->last_message_received !=
+                   config->early_write_after_message &&
+               RetryAsync(ssl, ret));
+      if (ret == 1) {
+        fprintf(stderr, "Handshake unexpectedly succeeded.\n");
+        return false;
+      }
+      if (test_state->last_message_received !=
+          config->early_write_after_message) {
+        // The handshake failed before we saw the target message. The generic
+        // error-handling logic in the caller will print the error.
+        return false;
+      }
+    }
+
     // Reset the state to assert later that the callback isn't called in
     // renegotations.
-    GetTestState(ssl)->got_new_session = false;
+    test_state->got_new_session = false;
   }
 
   if (config->export_keying_material > 0) {
@@ -1005,7 +1042,7 @@
       }
 
       // Let only one byte of the record through.
-      AsyncBioAllowWrite(GetTestState(ssl)->async_bio, 1);
+      AsyncBioAllowWrite(test_state->async_bio, 1);
       int write_ret =
           SSL_write(ssl, kInitialWrite, strlen(kInitialWrite));
       if (SSL_get_error(ssl, write_ret) != SSL_ERROR_WANT_WRITE) {
@@ -1060,7 +1097,7 @@
 
         // After a successful read, with or without False Start, the handshake
         // must be complete unless we are doing early data.
-        if (!GetTestState(ssl)->handshake_done &&
+        if (!test_state->handshake_done &&
             !SSL_early_data_accepted(ssl)) {
           fprintf(stderr, "handshake was not completed after SSL_read\n");
           return false;
@@ -1094,7 +1131,7 @@
       !config->implicit_handshake &&
       // Session tickets are sent post-handshake in TLS 1.3.
       GetProtocolVersion(ssl) < TLS1_3_VERSION &&
-      GetTestState(ssl)->got_new_session) {
+      test_state->got_new_session) {
     fprintf(stderr, "new session was established after the handshake\n");
     return false;
   }
@@ -1102,16 +1139,16 @@
   if (GetProtocolVersion(ssl) >= TLS1_3_VERSION && !config->is_server) {
     bool expect_new_session =
         !config->expect_no_session && !config->shim_shuts_down;
-    if (expect_new_session != GetTestState(ssl)->got_new_session) {
+    if (expect_new_session != test_state->got_new_session) {
       fprintf(stderr,
               "new session was%s cached, but we expected the opposite\n",
-              GetTestState(ssl)->got_new_session ? "" : " not");
+              test_state->got_new_session ? "" : " not");
       return false;
     }
 
     if (expect_new_session) {
       bool got_early_data =
-          GetTestState(ssl)->new_session->ticket_max_early_data != 0;
+          test_state->new_session->ticket_max_early_data != 0;
       if (config->expect_ticket_supports_early_data != got_early_data) {
         fprintf(stderr,
                 "new session did%s support early data, but we expected the "
@@ -1123,7 +1160,7 @@
   }
 
   if (out_session) {
-    *out_session = std::move(GetTestState(ssl)->new_session);
+    *out_session = std::move(test_state->new_session);
   }
 
   ret = DoShutdown(ssl);
@@ -1172,10 +1209,10 @@
 
   if (config->renegotiate_explicit &&
       SSL_total_renegotiations(ssl) !=
-          GetTestState(ssl)->explicit_renegotiates) {
+          test_state->explicit_renegotiates) {
     fprintf(stderr, "Performed %d renegotiations, but triggered %d of them\n",
             SSL_total_renegotiations(ssl),
-            GetTestState(ssl)->explicit_renegotiates);
+            test_state->explicit_renegotiates);
     return false;
   }
 
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index c169d2d..f802585 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1060,7 +1060,7 @@
 		shimPrefix = test.resumeShimPrefix
 	}
 	if test.shimWritesFirst || test.readWithUnfinishedWrite {
-		shimPrefix = "hello"
+		shimPrefix = shimInitialWrite
 	}
 	if test.renegotiate > 0 {
 		// If readWithUnfinishedWrite is set, the shim prefix will be
@@ -1294,6 +1294,10 @@
 	return errorStr
 }
 
+// shimInitialWrite is the data we expect from the shim when the
+// -shim-writes-first flag is used.
+const shimInitialWrite = "hello"
+
 func runTest(statusChan chan statusMsg, test *testCase, shimPath string, mallocNumToFail int64) error {
 	// Help debugging panics on the Go side.
 	defer func() {
@@ -1433,7 +1437,7 @@
 			// Configure the shim to send some data in early data.
 			flags = append(flags, "-on-resume-shim-writes-first")
 			if resumeConfig.Bugs.ExpectEarlyData == nil {
-				resumeConfig.Bugs.ExpectEarlyData = [][]byte{[]byte("hello")}
+				resumeConfig.Bugs.ExpectEarlyData = [][]byte{[]byte(shimInitialWrite)}
 			}
 		} else {
 			// By default, send some early data and expect half-RTT data response.
@@ -4842,10 +4846,10 @@
 					MinVersion:       VersionTLS13,
 					MaxEarlyDataSize: 2,
 					Bugs: ProtocolBugs{
-						ExpectEarlyData: [][]byte{{'h', 'e'}},
+						ExpectEarlyData: [][]byte{[]byte(shimInitialWrite[:2])},
 					},
 				},
-				resumeShimPrefix: "llo",
+				resumeShimPrefix: shimInitialWrite[2:],
 				resumeSession:    true,
 				earlyData:        true,
 			})
@@ -4865,8 +4869,9 @@
 					MaxVersion: VersionTLS13,
 					MinVersion: VersionTLS13,
 					Bugs: ProtocolBugs{
+						// Write the server response before expecting early data.
 						ExpectEarlyData:     [][]byte{},
-						ExpectLateEarlyData: [][]byte{{'h', 'e', 'l', 'l', 'o'}},
+						ExpectLateEarlyData: [][]byte{[]byte(shimInitialWrite)},
 					},
 				},
 				resumeSession: true,
@@ -15147,6 +15152,51 @@
 		expectedError:      ":CIPHER_MISMATCH_ON_EARLY_DATA:",
 		expectedLocalError: "remote error: illegal parameter",
 	})
+
+	// Test that the client can write early data when it has received a partial
+	// ServerHello..Finished flight. See https://crbug.com/1208784. Note the
+	// EncryptedExtensions test assumes EncryptedExtensions and Finished are in
+	// separate records, i.e. that PackHandshakeFlight is disabled.
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "EarlyData-WriteAfterServerHello",
+		config: Config{
+			MinVersion: VersionTLS13,
+			MaxVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				// Write the server response before expecting early data.
+				ExpectEarlyData:     [][]byte{},
+				ExpectLateEarlyData: [][]byte{[]byte(shimInitialWrite)},
+			},
+		},
+		resumeSession: true,
+		earlyData:     true,
+		flags: []string{
+			"-async",
+			"-on-resume-early-write-after-message",
+			strconv.Itoa(int(typeServerHello)),
+		},
+	})
+	testCases = append(testCases, testCase{
+		testType: clientTest,
+		name:     "EarlyData-WriteAfterEncryptedExtensions",
+		config: Config{
+			MinVersion: VersionTLS13,
+			MaxVersion: VersionTLS13,
+			Bugs: ProtocolBugs{
+				// Write the server response before expecting early data.
+				ExpectEarlyData:     [][]byte{},
+				ExpectLateEarlyData: [][]byte{[]byte(shimInitialWrite)},
+			},
+		},
+		resumeSession: true,
+		earlyData:     true,
+		flags: []string{
+			"-async",
+			"-on-resume-early-write-after-message",
+			strconv.Itoa(int(typeEncryptedExtensions)),
+		},
+	})
 }
 
 func addTLS13CipherPreferenceTests() {
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index fff536f..e933f0f 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -235,6 +235,7 @@
     {"-read-size", &TestConfig::read_size},
     {"-expect-ticket-age-skew", &TestConfig::expect_ticket_age_skew},
     {"-quic-use-legacy-codepoint", &TestConfig::quic_use_legacy_codepoint},
+    {"-early-write-after-message", &TestConfig::early_write_after_message},
 };
 
 const Flag<std::vector<int>> kIntVectorFlags[] = {
@@ -599,6 +600,9 @@
       char text[16];
       snprintf(text, sizeof(text), "hs %d\n", type);
       state->msg_callback_text += text;
+      if (!is_write) {
+        state->last_message_received = type;
+      }
       return;
     }
 
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index f4e3f61..f4ddba2 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -190,6 +190,7 @@
   bool expect_no_hrr = false;
   bool wait_for_debugger = false;
   std::string quic_early_data_context;
+  int early_write_after_message = 0;
 
   int argc;
   char **argv;
diff --git a/ssl/test/test_state.h b/ssl/test/test_state.h
index 2c558a4..d9fe945 100644
--- a/ssl/test/test_state.h
+++ b/ssl/test/test_state.h
@@ -68,6 +68,7 @@
   bool cert_verified = false;
   int explicit_renegotiates = 0;
   std::function<bool(const SSL_CLIENT_HELLO*)> get_handshake_hints_cb;
+  int last_message_received = -1;
 };
 
 bool SetTestState(SSL *ssl, std::unique_ptr<TestState> state);