Fix an edge case in SSL_write's retry mechanism.

This is split out from
https://boringssl-review.googlesource.com/c/boringssl/+/47544 just to
get the bugfixes and tests out of the way of the refactor.

If we trip the SSL_R_BAD_LENGTH check in tls_write_app_data, wnum is set
to zero. But wnum should only be cleared on a successful write. It
tracks the number of input bytes that have been written to the transport
but not yet reported to the caller. Instead, move it to the success
return in that function. All the other error paths already set it to
something else.

Change-Id: Ib22f9cf04454ecdb0062077f183be5070ab7d791
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/53545
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index 450f7dc..bc3e99a 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -143,7 +143,6 @@
   // TODO(davidben): Switch this logic to |size_t| and |bssl::Span|.
   assert(ssl->s3->wnum <= INT_MAX);
   unsigned tot = ssl->s3->wnum;
-  ssl->s3->wnum = 0;
 
   // Ensure that if we end up with a smaller value of data to write out than
   // the the original len from a write which didn't complete for non-blocking
@@ -188,6 +187,7 @@
     }
 
     if (ret == (int)n || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) {
+      ssl->s3->wnum = 0;
       return tot + ret;
     }
 
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 6bc015c..0811cb4 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -3904,18 +3904,18 @@
                           {cert_.get(), cert_.get()}));
 }
 
-static bool ExpectBadWriteRetry() {
+static bool ExpectSingleError(int lib, int reason) {
+  const char *expected = ERR_reason_error_string(ERR_PACK(lib, reason));
   int err = ERR_get_error();
-  if (ERR_GET_LIB(err) != ERR_LIB_SSL ||
-      ERR_GET_REASON(err) != SSL_R_BAD_WRITE_RETRY) {
+  if (ERR_GET_LIB(err) != lib || ERR_GET_REASON(err) != reason) {
     char buf[ERR_ERROR_STRING_BUF_LEN];
     ERR_error_string_n(err, buf, sizeof(buf));
-    fprintf(stderr, "Wanted SSL_R_BAD_WRITE_RETRY, got: %s.\n", buf);
+    fprintf(stderr, "Wanted %s, got: %s.\n", expected, buf);
     return false;
   }
 
   if (ERR_peek_error() != 0) {
-    fprintf(stderr, "Unexpected error following SSL_R_BAD_WRITE_RETRY.\n");
+    fprintf(stderr, "Unexpected error following %s.\n", expected);
     return false;
   }
 
@@ -3931,8 +3931,6 @@
     SCOPED_TRACE(enable_partial_write);
 
     // Connect a client and server.
-    ASSERT_TRUE(UseCertAndKey(client_ctx_.get()));
-
     ASSERT_TRUE(Connect());
 
     if (enable_partial_write) {
@@ -3951,9 +3949,7 @@
         ASSERT_EQ(SSL_get_error(client_.get(), ret), SSL_ERROR_WANT_WRITE);
         break;
       }
-
       ASSERT_EQ(ret, 5);
-
       count++;
     }
 
@@ -3966,14 +3962,14 @@
     ASSERT_EQ(SSL_get_error(client_.get(),
                             SSL_write(client_.get(), data, kChunkLen - 1)),
               SSL_ERROR_SSL);
-    ASSERT_TRUE(ExpectBadWriteRetry());
+    ASSERT_TRUE(ExpectSingleError(ERR_LIB_SSL, SSL_R_BAD_WRITE_RETRY));
 
     // Retrying with a different buffer pointer is not legal.
     char data2[] = "hello";
     ASSERT_EQ(SSL_get_error(client_.get(),
                             SSL_write(client_.get(), data2, kChunkLen)),
               SSL_ERROR_SSL);
-    ASSERT_TRUE(ExpectBadWriteRetry());
+    ASSERT_TRUE(ExpectSingleError(ERR_LIB_SSL, SSL_R_BAD_WRITE_RETRY));
 
     // With |SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER|, the buffer may move.
     SSL_set_mode(client_.get(), SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER);
@@ -3985,7 +3981,7 @@
     ASSERT_EQ(SSL_get_error(client_.get(),
                             SSL_write(client_.get(), data2, kChunkLen - 1)),
               SSL_ERROR_SSL);
-    ASSERT_TRUE(ExpectBadWriteRetry());
+    ASSERT_TRUE(ExpectSingleError(ERR_LIB_SSL, SSL_R_BAD_WRITE_RETRY));
 
     // Retrying with a larger buffer is legal.
     ASSERT_EQ(SSL_get_error(client_.get(),
@@ -4003,20 +3999,94 @@
     // pending record, skip over that many bytes of input (on assumption they
     // are the same), and write the remainder. If SSL_MODE_ENABLE_PARTIAL_WRITE
     // is set, this will complete in two steps.
-    char data3[] = "_____!";
+    char data_longer[] = "_____!!!!!";
     if (enable_partial_write) {
-      ASSERT_EQ(SSL_write(client_.get(), data3, kChunkLen + 1), kChunkLen);
-      ASSERT_EQ(SSL_write(client_.get(), data3 + kChunkLen, 1), 1);
+      ASSERT_EQ(SSL_write(client_.get(), data_longer, 2 * kChunkLen),
+                kChunkLen);
+      ASSERT_EQ(SSL_write(client_.get(), data_longer + kChunkLen, kChunkLen),
+                kChunkLen);
     } else {
-      ASSERT_EQ(SSL_write(client_.get(), data3, kChunkLen + 1), kChunkLen + 1);
+      ASSERT_EQ(SSL_write(client_.get(), data_longer, 2 * kChunkLen),
+                2 * kChunkLen);
     }
 
     // Check the last write was correct. The data will be spread over two
     // records, so SSL_read returns twice.
     ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
     ASSERT_EQ(OPENSSL_memcmp(buf, "hello", kChunkLen), 0);
-    ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), 1);
-    ASSERT_EQ(buf[0], '!');
+    ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+    ASSERT_EQ(OPENSSL_memcmp(buf, "!!!!!", kChunkLen), 0);
+
+    // Fill the transport buffer again. This time only leave room for one
+    // record.
+    count = 0;
+    for (;;) {
+      int ret = SSL_write(client_.get(), data, kChunkLen);
+      if (ret <= 0) {
+        ASSERT_EQ(SSL_get_error(client_.get(), ret), SSL_ERROR_WANT_WRITE);
+        break;
+      }
+      ASSERT_EQ(ret, 5);
+      count++;
+    }
+    ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+    ASSERT_EQ(OPENSSL_memcmp(buf, "hello", kChunkLen), 0);
+    count--;
+
+    // Retry the last write, with a longer input. The first half is the most
+    // recently failed write, from filling the buffer. |SSL_write| should write
+    // that to the transport, and then attempt to write the second half.
+    int ret = SSL_write(client_.get(), data_longer, 2 * kChunkLen);
+    if (enable_partial_write) {
+      // If partial writes are allowed, the write will succeed partially.
+      ASSERT_EQ(ret, kChunkLen);
+
+      // Check the first half and make room for another record.
+      ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+      ASSERT_EQ(OPENSSL_memcmp(buf, "hello", kChunkLen), 0);
+      count--;
+
+      // Finish writing the input.
+      ASSERT_EQ(SSL_write(client_.get(), data_longer + kChunkLen, kChunkLen),
+                kChunkLen);
+    } else {
+      // Otherwise, although the first half made it to the transport, the second
+      // half is blocked.
+      ASSERT_EQ(ret, -1);
+      ASSERT_EQ(SSL_get_error(client_.get(), -1), SSL_ERROR_WANT_WRITE);
+
+      // Check the first half and make room for another record.
+      ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+      ASSERT_EQ(OPENSSL_memcmp(buf, "hello", kChunkLen), 0);
+      count--;
+
+      // Retrying with fewer bytes than previously attempted is an error. If the
+      // input length is less than the number of bytes successfully written, the
+      // check happens at a different point, with a different error.
+      //
+      // TODO(davidben): Should these cases use the same error?
+      ASSERT_EQ(
+          SSL_get_error(client_.get(),
+                        SSL_write(client_.get(), data_longer, kChunkLen - 1)),
+          SSL_ERROR_SSL);
+      ASSERT_TRUE(ExpectSingleError(ERR_LIB_SSL, SSL_R_BAD_LENGTH));
+
+      // Complete the write with the correct retry.
+      ASSERT_EQ(SSL_write(client_.get(), data_longer, 2 * kChunkLen),
+                2 * kChunkLen);
+    }
+
+    // Drain the input and ensure everything was written correctly.
+    for (unsigned i = 0; i < count; i++) {
+      ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+      ASSERT_EQ(OPENSSL_memcmp(buf, "hello", kChunkLen), 0);
+    }
+
+    // The final write is spread over two records.
+    ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+    ASSERT_EQ(OPENSSL_memcmp(buf, "hello", kChunkLen), 0);
+    ASSERT_EQ(SSL_read(server_.get(), buf, sizeof(buf)), kChunkLen);
+    ASSERT_EQ(OPENSSL_memcmp(buf, "!!!!!", kChunkLen), 0);
   }
 }