Convert more of the SSL write path to size_t and Spans.

We still have our <= 0 return values because anything with BIOs tries to
preserve BIO_write's error returns. (Maybe we can stop doing this?
BIO_read's error return is a little subtle with EOF vs error, but
BIO_write's is uninteresting.) But the rest of the logic is size_t-clean
and hopefully a little clearer. We still have to support SSL_write's
rather goofy calling convention, however.

I haven't pushed Spans down into the low-level record construction logic
yet. We should probably do that, but there are enough offsets tossed
around there that they warrant their own CL.

Bug: 507
Change-Id: Ia0c702d1a2d3713e71b0bbfa8d65649d3b20da9b
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47544
Commit-Queue: Bob Beck <bbe@google.com>
Reviewed-by: Bob Beck <bbe@google.com>
diff --git a/ssl/d1_pkt.cc b/ssl/d1_pkt.cc
index b9b0ef9..b866156 100644
--- a/ssl/d1_pkt.cc
+++ b/ssl/d1_pkt.cc
@@ -186,8 +186,8 @@
   return ssl_open_record_success;
 }
 
-int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in,
-                         int len) {
+int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake,
+                         size_t *out_bytes_written, Span<const uint8_t> in) {
   assert(!SSL_in_init(ssl));
   *out_needs_handshake = false;
 
@@ -196,47 +196,46 @@
     return -1;
   }
 
-  if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
+  // DTLS does not split the input across records.
+  if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_DTLS_MESSAGE_TOO_BIG);
     return -1;
   }
 
-  if (len < 0) {
-    OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH);
-    return -1;
+  if (in.empty()) {
+    *out_bytes_written = 0;
+    return 1;
   }
 
-  if (len == 0) {
-    return 0;
-  }
-
-  int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in, (size_t)len,
+  int ret = dtls1_write_record(ssl, SSL3_RT_APPLICATION_DATA, in,
                                dtls1_use_current_epoch);
   if (ret <= 0) {
     return ret;
   }
-  return len;
+  *out_bytes_written = in.size();
+  return 1;
 }
 
-int dtls1_write_record(SSL *ssl, int type, const uint8_t *in, size_t len,
+int dtls1_write_record(SSL *ssl, int type, Span<const uint8_t> in,
                        enum dtls1_use_epoch_t use_epoch) {
   SSLBuffer *buf = &ssl->s3->write_buffer;
-  assert(len <= SSL3_RT_MAX_PLAIN_LENGTH);
+  assert(in.size() <= SSL3_RT_MAX_PLAIN_LENGTH);
   // There should never be a pending write buffer in DTLS. One can't write half
   // a datagram, so the write buffer is always dropped in
   // |ssl_write_buffer_flush|.
   assert(buf->empty());
 
-  if (len > SSL3_RT_MAX_PLAIN_LENGTH) {
+  if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return -1;
   }
 
   size_t ciphertext_len;
   if (!buf->EnsureCap(ssl_seal_align_prefix_len(ssl),
-                      len + SSL_max_seal_overhead(ssl)) ||
+                      in.size() + SSL_max_seal_overhead(ssl)) ||
       !dtls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
-                        buf->remaining().size(), type, in, len, use_epoch)) {
+                        buf->remaining().size(), type, in.data(), in.size(),
+                        use_epoch)) {
     buf->Clear();
     return -1;
   }
@@ -250,7 +249,7 @@
 }
 
 int dtls1_dispatch_alert(SSL *ssl) {
-  int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, &ssl->s3->send_alert[0], 2,
+  int ret = dtls1_write_record(ssl, SSL3_RT_ALERT, ssl->s3->send_alert,
                                dtls1_use_current_epoch);
   if (ret <= 0) {
     return ret;
diff --git a/ssl/internal.h b/ssl/internal.h
index 1a78f63..0a15ace 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -2456,8 +2456,13 @@
   ssl_open_record_t (*open_app_data)(SSL *ssl, Span<uint8_t> *out,
                                      size_t *out_consumed, uint8_t *out_alert,
                                      Span<uint8_t> in);
-  int (*write_app_data)(SSL *ssl, bool *out_needs_handshake, const uint8_t *buf,
-                        int len);
+  // write_app_data encrypts and writes |in| as application data. On success, it
+  // returns one and sets |*out_bytes_written| to the number of bytes of |in|
+  // written. Otherwise, it returns <= 0 and sets |*out_needs_handshake| to
+  // whether the operation failed because the caller needs to drive the
+  // handshake.
+  int (*write_app_data)(SSL *ssl, bool *out_needs_handshake,
+                        size_t *out_bytes_written, Span<const uint8_t> in);
   int (*dispatch_alert)(SSL *ssl);
   // init_message begins a new handshake message of type |type|. |cbb| is the
   // root CBB to be passed into |finish_message|. |*body| is set to a child CBB
@@ -2646,11 +2651,23 @@
   // |read_buffer|.
   Span<uint8_t> pending_app_data;
 
-  // partial write - check the numbers match
-  unsigned int wnum = 0;  // number of bytes sent so far
-  int wpend_tot = 0;      // number bytes written
-  int wpend_type = 0;
-  const uint8_t *wpend_buf = nullptr;
+  // unreported_bytes_written is the number of bytes successfully written to the
+  // transport, but not yet reported to the caller. The next |SSL_write| will
+  // skip this many bytes from the input. This is used if
+  // |SSL_MODE_ENABLE_PARTIAL_WRITE| is disabled, in which case |SSL_write| only
+  // reports bytes written when the full caller input is written.
+  size_t unreported_bytes_written = 0;
+
+  // pending_write, if |has_pending_write| is true, is the caller-supplied data
+  // corresponding to the current pending write. This is used to check the
+  // caller retried with a compatible buffer.
+  Span<const uint8_t> pending_write;
+
+  // pending_write_type, if |has_pending_write| is true, is the record type
+  // for the current pending write.
+  //
+  // TODO(davidben): Remove this when alerts are moved out of this write path.
+  uint8_t pending_write_type = 0;
 
   // read_shutdown is the shutdown state for the read half of the connection.
   enum ssl_shutdown_t read_shutdown = ssl_shutdown_none;
@@ -3214,8 +3231,8 @@
 ssl_open_record_t tls_open_change_cipher_spec(SSL *ssl, size_t *out_consumed,
                                               uint8_t *out_alert,
                                               Span<uint8_t> in);
-int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *buf,
-                       int len);
+int tls_write_app_data(SSL *ssl, bool *out_needs_handshake,
+                       size_t *out_bytes_written, Span<const uint8_t> in);
 
 bool tls_new(SSL *ssl);
 void tls_free(SSL *ssl);
@@ -3248,11 +3265,11 @@
                                                 Span<uint8_t> in);
 
 int dtls1_write_app_data(SSL *ssl, bool *out_needs_handshake,
-                         const uint8_t *buf, int len);
+                         size_t *out_bytes_written, Span<const uint8_t> in);
 
 // dtls1_write_record sends a record. It returns one on success and <= 0 on
 // error.
-int dtls1_write_record(SSL *ssl, int type, const uint8_t *buf, size_t len,
+int dtls1_write_record(SSL *ssl, int type, Span<const uint8_t> in,
                        enum dtls1_use_epoch_t use_epoch);
 
 int dtls1_retransmit_outgoing_messages(SSL *ssl);
diff --git a/ssl/s3_pkt.cc b/ssl/s3_pkt.cc
index efe5905..bc0d13d 100644
--- a/ssl/s3_pkt.cc
+++ b/ssl/s3_pkt.cc
@@ -126,10 +126,11 @@
 
 BSSL_NAMESPACE_BEGIN
 
-static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len);
+static int do_tls_write(SSL *ssl, size_t *out_bytes_written, uint8_t type,
+                        Span<const uint8_t> in);
 
-int tls_write_app_data(SSL *ssl, bool *out_needs_handshake, const uint8_t *in,
-                       int len) {
+int tls_write_app_data(SSL *ssl, bool *out_needs_handshake,
+                       size_t *out_bytes_written, Span<const uint8_t> in) {
   assert(ssl_can_write(ssl));
   assert(!ssl->s3->aead_write_ctx->is_null_cipher());
 
@@ -140,32 +141,28 @@
     return -1;
   }
 
-  // TODO(davidben): Switch this logic to |size_t| and |bssl::Span|.
-  assert(ssl->s3->wnum <= INT_MAX);
-  unsigned tot = ssl->s3->wnum;
-
-  // Ensure that if we end up with a smaller value of data to write out than
-  // the the original len from a write which didn't complete for non-blocking
-  // I/O and also somehow ended up avoiding the check for this in
-  // do_tls_write/SSL_R_BAD_WRITE_RETRY as it must never be possible to end up
-  // with (len-tot) as a large number that will then promptly send beyond the
-  // end of the users buffer ... so we trap and report the error in a way the
-  // user will notice.
-  if (len < 0 || (size_t)len < tot) {
+  size_t total_bytes_written = ssl->s3->unreported_bytes_written;
+  if (in.size() < total_bytes_written) {
+    // This can happen if the caller disables |SSL_MODE_ENABLE_PARTIAL_WRITE|,
+    // asks us to write some input of length N, we successfully encrypt M bytes
+    // and write it, but fail to write the rest. We will report
+    // |SSL_ERROR_WANT_WRITE|. If the caller then retries with fewer than M
+    // bytes, we cannot satisfy that request. The caller is required to always
+    // retry with at least as many bytes as the previous attempt.
     OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH);
     return -1;
   }
 
-  const int is_early_data_write =
-      !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write;
+  in = in.subspan(total_bytes_written);
 
-  unsigned n = len - tot;
+  const bool is_early_data_write =
+      !ssl->server && SSL_in_early_data(ssl) && ssl->s3->hs->can_early_write;
   for (;;) {
     size_t max_send_fragment = ssl->max_send_fragment;
     if (is_early_data_write) {
       SSL_HANDSHAKE *hs = ssl->s3->hs.get();
       if (hs->early_data_written >= hs->early_session->ticket_max_early_data) {
-        ssl->s3->wnum = tot;
+        ssl->s3->unreported_bytes_written = total_bytes_written;
         hs->can_early_write = false;
         *out_needs_handshake = true;
         return -1;
@@ -175,35 +172,43 @@
                                     hs->early_data_written});
     }
 
-    const size_t nw = std::min(max_send_fragment, size_t{n});
-    int ret = do_tls_write(ssl, SSL3_RT_APPLICATION_DATA, &in[tot], nw);
+    const size_t to_write = std::min(max_send_fragment, in.size());
+    size_t bytes_written;
+    int ret = do_tls_write(ssl, &bytes_written, SSL3_RT_APPLICATION_DATA,
+                           in.subspan(0, to_write));
     if (ret <= 0) {
-      ssl->s3->wnum = tot;
+      ssl->s3->unreported_bytes_written = total_bytes_written;
       return ret;
     }
 
+    // Note |bytes_written| may be less than |to_write| if there was a pending
+    // record from a smaller write attempt.
+    assert(bytes_written <= to_write);
+    total_bytes_written += bytes_written;
+    in = in.subspan(bytes_written);
     if (is_early_data_write) {
-      ssl->s3->hs->early_data_written += ret;
+      ssl->s3->hs->early_data_written += bytes_written;
     }
 
-    if (ret == (int)n || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) {
-      ssl->s3->wnum = 0;
-      return tot + ret;
+    if (in.empty() || (ssl->mode & SSL_MODE_ENABLE_PARTIAL_WRITE)) {
+      ssl->s3->unreported_bytes_written = 0;
+      *out_bytes_written = total_bytes_written;
+      return 1;
     }
-
-    n -= ret;
-    tot += ret;
   }
 }
 
-// do_tls_write writes an SSL record of the given type.
-static int do_tls_write(SSL *ssl, int type, const uint8_t *in, unsigned len) {
+// do_tls_write writes an SSL record of the given type. On success, it sets
+// |*out_bytes_written| to number of bytes successfully written and returns one.
+// On error, it returns a value <= 0 from the underlying |BIO|.
+static int do_tls_write(SSL *ssl, size_t *out_bytes_written, uint8_t type,
+                        Span<const uint8_t> in) {
   // If there is a pending write, the retry must be consistent.
-  if (ssl->s3->wpend_tot > 0 &&
-      (ssl->s3->wpend_tot > (int)len ||
+  if (!ssl->s3->pending_write.empty() &&
+      (ssl->s3->pending_write.size() > in.size() ||
        (!(ssl->mode & SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER) &&
-        ssl->s3->wpend_buf != in) ||
-       ssl->s3->wpend_type != type)) {
+        ssl->s3->pending_write.data() != in.data()) ||
+       ssl->s3->pending_write_type != type)) {
     OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_WRITE_RETRY);
     return -1;
   }
@@ -216,15 +221,14 @@
   }
 
   // If there is a pending write, we just completed it. Report it to the caller.
-  if (ssl->s3->wpend_tot > 0) {
-    ret = ssl->s3->wpend_tot;
-    ssl->s3->wpend_buf = nullptr;
-    ssl->s3->wpend_tot = 0;
-    return ret;
+  if (!ssl->s3->pending_write.empty()) {
+    *out_bytes_written = ssl->s3->pending_write.size();
+    ssl->s3->pending_write = {};
+    return 1;
   }
 
   SSLBuffer *buf = &ssl->s3->write_buffer;
-  if (len > SSL3_RT_MAX_PLAIN_LENGTH || buf->size() > 0) {
+  if (in.size() > SSL3_RT_MAX_PLAIN_LENGTH || buf->size() > 0) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return -1;
   }
@@ -233,16 +237,22 @@
     return -1;
   }
 
-  size_t flight_len = 0;
+  // We may have unflushed handshake data that must be written before |in|. This
+  // may be a KeyUpdate acknowledgment, 0-RTT key change messages, or a
+  // NewSessionTicket.
+  Span<const uint8_t> pending_flight;
   if (ssl->s3->pending_flight != nullptr) {
-    flight_len =
-        ssl->s3->pending_flight->length - ssl->s3->pending_flight_offset;
+    pending_flight = MakeConstSpan(
+        reinterpret_cast<const uint8_t *>(ssl->s3->pending_flight->data),
+        ssl->s3->pending_flight->length);
+    pending_flight = pending_flight.subspan(ssl->s3->pending_flight_offset);
   }
 
-  size_t max_out = flight_len;
-  if (len > 0) {
-    const size_t max_ciphertext_len = len + SSL_max_seal_overhead(ssl);
-    if (max_ciphertext_len < len || max_out + max_ciphertext_len < max_out) {
+  size_t max_out = pending_flight.size();
+  if (!in.empty()) {
+    const size_t max_ciphertext_len = in.size() + SSL_max_seal_overhead(ssl);
+    if (max_ciphertext_len < in.size() ||
+        max_out + max_ciphertext_len < max_out) {
       OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
       return -1;
     }
@@ -250,31 +260,29 @@
   }
 
   if (max_out == 0) {
-    return 0;
+    // Nothing to write.
+    *out_bytes_written = 0;
+    return 1;
   }
 
-  if (!buf->EnsureCap(flight_len + ssl_seal_align_prefix_len(ssl), max_out)) {
+  if (!buf->EnsureCap(pending_flight.size() + ssl_seal_align_prefix_len(ssl),
+                      max_out)) {
     return -1;
   }
 
-  // Add any unflushed handshake data as a prefix. This may be a KeyUpdate
-  // acknowledgment or 0-RTT key change messages. |pending_flight| must be clear
-  // when data is added to |write_buffer| or it will be written in the wrong
-  // order.
-  if (ssl->s3->pending_flight != nullptr) {
-    OPENSSL_memcpy(
-        buf->remaining().data(),
-        ssl->s3->pending_flight->data + ssl->s3->pending_flight_offset,
-        flight_len);
+  // Copy |pending_flight| to the output.
+  if (!pending_flight.empty()) {
+    OPENSSL_memcpy(buf->remaining().data(), pending_flight.data(),
+                   pending_flight.size());
     ssl->s3->pending_flight.reset();
     ssl->s3->pending_flight_offset = 0;
-    buf->DidWrite(flight_len);
+    buf->DidWrite(pending_flight.size());
   }
 
-  if (len > 0) {
+  if (!in.empty()) {
     size_t ciphertext_len;
     if (!tls_seal_record(ssl, buf->remaining().data(), &ciphertext_len,
-                         buf->remaining().size(), type, in, len)) {
+                         buf->remaining().size(), type, in.data(), in.size())) {
       return -1;
     }
     buf->DidWrite(ciphertext_len);
@@ -288,15 +296,15 @@
   ret = ssl_write_buffer_flush(ssl);
   if (ret <= 0) {
     // Track the unfinished write.
-    if (len > 0) {
-      ssl->s3->wpend_tot = len;
-      ssl->s3->wpend_buf = in;
-      ssl->s3->wpend_type = type;
+    if (!in.empty()) {
+      ssl->s3->pending_write = in;
+      ssl->s3->pending_write_type = type;
     }
     return ret;
   }
 
-  return len;
+  *out_bytes_written = in.size();
+  return 1;
 }
 
 ssl_open_record_t tls_open_app_data(SSL *ssl, Span<uint8_t> *out,
@@ -434,10 +442,13 @@
       return 0;
     }
   } else {
-    int ret = do_tls_write(ssl, SSL3_RT_ALERT, &ssl->s3->send_alert[0], 2);
+    size_t bytes_written;
+    int ret =
+        do_tls_write(ssl, &bytes_written, SSL3_RT_ALERT, ssl->s3->send_alert);
     if (ret <= 0) {
       return ret;
     }
+    assert(bytes_written == 2);
   }
 
   ssl->s3->alert_dispatch = false;
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index 7035748..4d3ad44 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1058,6 +1058,7 @@
   }
 
   int ret = 0;
+  size_t bytes_written = 0;
   bool needs_handshake = false;
   do {
     // If necessary, complete the handshake implicitly.
@@ -1072,10 +1073,16 @@
       }
     }
 
-    ret = ssl->method->write_app_data(ssl, &needs_handshake,
-                                      (const uint8_t *)buf, num);
+    if (num < 0) {
+      OPENSSL_PUT_ERROR(SSL, SSL_R_BAD_LENGTH);
+      return -1;
+    }
+    ret = ssl->method->write_app_data(
+        ssl, &needs_handshake, &bytes_written,
+        MakeConstSpan(static_cast<const uint8_t *>(buf),
+                      static_cast<size_t>(num)));
   } while (needs_handshake);
-  return ret;
+  return ret <= 0 ? ret : static_cast<int>(bytes_written);
 }
 
 int SSL_key_update(SSL *ssl, int request_type) {
@@ -1239,8 +1246,7 @@
   // Discard any unfinished writes from the perspective of |SSL_write|'s
   // retry. The handshake will transparently flush out the pending record
   // (discarded by the server) to keep the framing correct.
-  ssl->s3->wpend_buf = nullptr;
-  ssl->s3->wpend_tot = 0;
+  ssl->s3->pending_write = {};
 }
 
 enum ssl_early_data_reason_t SSL_get_early_data_reason(const SSL *ssl) {