Tidy up the PSK binder logic.

Computing the binders on ClientHelloInner is a little interesting. While
I'm in the area, tidy this up a bit. The exploded parameters may as well
be an SSL_SESSION, and hash_transcript_and_truncated_client_hello can
just get folded in.

Change-Id: I9d3a7e0ae9f391d6b9a23b51b5d7198e15569b11
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/47997
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/internal.h b/ssl/internal.h
index 72e1fba..6a1a650 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -700,9 +700,9 @@
   // the transcript. It returns true on success and false on failure. If the
   // handshake buffer is still present, |digest| may be any supported digest.
   // Otherwise, |digest| must match the transcript hash.
-  bool CopyToHashContext(EVP_MD_CTX *ctx, const EVP_MD *digest);
+  bool CopyToHashContext(EVP_MD_CTX *ctx, const EVP_MD *digest) const;
 
-  Span<const uint8_t> buffer() {
+  Span<const uint8_t> buffer() const {
     return MakeConstSpan(reinterpret_cast<const uint8_t *>(buffer_->data),
                          buffer_->length);
   }
@@ -725,14 +725,14 @@
   // GetHash writes the handshake hash to |out| which must have room for at
   // least |DigestLen| bytes. On success, it returns true and sets |*out_len| to
   // the number of bytes written. Otherwise, it returns false.
-  bool GetHash(uint8_t *out, size_t *out_len);
+  bool GetHash(uint8_t *out, size_t *out_len) const;
 
   // GetFinishedMAC computes the MAC for the Finished message into the bytes
   // pointed by |out| and writes the number of bytes to |*out_len|. |out| must
   // have room for |EVP_MAX_MD_SIZE| bytes. It returns true on success and false
   // on failure.
   bool GetFinishedMAC(uint8_t *out, size_t *out_len, const SSL_SESSION *session,
-                      bool from_server);
+                      bool from_server) const;
 
  private:
   // buffer_, if non-null, contains the handshake transcript.
@@ -1418,13 +1418,14 @@
 // tls13_write_psk_binder calculates the PSK binder value and replaces the last
 // bytes of |msg| with the resulting value. It returns true on success, and
 // false on failure.
-bool tls13_write_psk_binder(SSL_HANDSHAKE *hs, Span<uint8_t> msg);
+bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs, Span<uint8_t> msg);
 
 // tls13_verify_psk_binder verifies that the handshake transcript, truncated up
 // to the binders has a valid signature using the value of |session|'s
 // resumption secret. It returns true on success, and false on failure.
-bool tls13_verify_psk_binder(SSL_HANDSHAKE *hs, SSL_SESSION *session,
-                             const SSLMessage &msg, CBS *binders);
+bool tls13_verify_psk_binder(const SSL_HANDSHAKE *hs,
+                             const SSL_SESSION *session, const SSLMessage &msg,
+                             CBS *binders);
 
 
 // Encrypted ClientHello.
diff --git a/ssl/ssl_transcript.cc b/ssl/ssl_transcript.cc
index 0bc13b9..1599c80 100644
--- a/ssl/ssl_transcript.cc
+++ b/ssl/ssl_transcript.cc
@@ -206,7 +206,8 @@
   return true;
 }
 
-bool SSLTranscript::CopyToHashContext(EVP_MD_CTX *ctx, const EVP_MD *digest) {
+bool SSLTranscript::CopyToHashContext(EVP_MD_CTX *ctx,
+                                      const EVP_MD *digest) const {
   const EVP_MD *transcript_digest = Digest();
   if (transcript_digest != nullptr &&
       EVP_MD_type(transcript_digest) == EVP_MD_type(digest)) {
@@ -237,7 +238,7 @@
   return true;
 }
 
-bool SSLTranscript::GetHash(uint8_t *out, size_t *out_len) {
+bool SSLTranscript::GetHash(uint8_t *out, size_t *out_len) const {
   ScopedEVP_MD_CTX ctx;
   unsigned len;
   if (!EVP_MD_CTX_copy_ex(ctx.get(), hash_.get()) ||
@@ -250,7 +251,7 @@
 
 bool SSLTranscript::GetFinishedMAC(uint8_t *out, size_t *out_len,
                                    const SSL_SESSION *session,
-                                   bool from_server) {
+                                   bool from_server) const {
   static const char kClientLabel[] = "client finished";
   static const char kServerLabel[] = "server finished";
   auto label = from_server
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index 9c3063f..2ac9985 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -395,29 +395,52 @@
 
 static const char kTLS13LabelPSKBinder[] = "res binder";
 
-static bool tls13_psk_binder(uint8_t *out, size_t *out_len, uint16_t version,
-                             const EVP_MD *digest, Span<const uint8_t> psk,
-                             Span<const uint8_t> context) {
+static bool tls13_psk_binder(uint8_t *out, size_t *out_len,
+                             const SSL_SESSION *session,
+                             const SSLTranscript &transcript,
+                             Span<const uint8_t> client_hello,
+                             size_t binders_len) {
+  const EVP_MD *digest = ssl_session_get_digest(session);
+
+  // Compute the binder key.
+  //
+  // TODO(davidben): Ideally we wouldn't recompute early secret and the binder
+  // key each time.
   uint8_t binder_context[EVP_MAX_MD_SIZE];
   unsigned binder_context_len;
-  if (!EVP_Digest(NULL, 0, binder_context, &binder_context_len, digest, NULL)) {
-    return false;
-  }
-
   uint8_t early_secret[EVP_MAX_MD_SIZE] = {0};
   size_t early_secret_len;
-  if (!HKDF_extract(early_secret, &early_secret_len, digest, psk.data(),
-                    psk.size(), NULL, 0)) {
+  uint8_t binder_key_buf[EVP_MAX_MD_SIZE] = {0};
+  auto binder_key = MakeSpan(binder_key_buf, EVP_MD_size(digest));
+  if (!EVP_Digest(nullptr, 0, binder_context, &binder_context_len, digest,
+                  nullptr) ||
+      !HKDF_extract(early_secret, &early_secret_len, digest, session->secret,
+                    session->secret_length, nullptr, 0) ||
+      !hkdf_expand_label(binder_key, digest,
+                         MakeConstSpan(early_secret, early_secret_len),
+                         label_to_span(kTLS13LabelPSKBinder),
+                         MakeConstSpan(binder_context, binder_context_len))) {
     return false;
   }
 
-  uint8_t binder_key_buf[EVP_MAX_MD_SIZE] = {0};
-  auto binder_key = MakeSpan(binder_key_buf, EVP_MD_size(digest));
-  if (!hkdf_expand_label(binder_key, digest,
-                         MakeConstSpan(early_secret, early_secret_len),
-                         label_to_span(kTLS13LabelPSKBinder),
-                         MakeConstSpan(binder_context, binder_context_len)) ||
-      !tls13_verify_data(out, out_len, digest, version, binder_key, context)) {
+  // Hash the transcript and truncated ClientHello.
+  if (client_hello.size() < binders_len) {
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    return false;
+  }
+  auto truncated = client_hello.subspan(0, client_hello.size() - binders_len);
+  uint8_t context[EVP_MAX_MD_SIZE];
+  unsigned context_len;
+  ScopedEVP_MD_CTX ctx;
+  if (!transcript.CopyToHashContext(ctx.get(), digest) ||
+      !EVP_DigestUpdate(ctx.get(), truncated.data(),
+                        truncated.size()) ||
+      !EVP_DigestFinal_ex(ctx.get(), context, &context_len)) {
+    return false;
+  }
+
+  if (!tls13_verify_data(out, out_len, digest, session->ssl_version, binder_key,
+                         MakeConstSpan(context, context_len))) {
     return false;
   }
 
@@ -425,44 +448,19 @@
   return true;
 }
 
-static bool hash_transcript_and_truncated_client_hello(
-    SSL_HANDSHAKE *hs, uint8_t *out, size_t *out_len, const EVP_MD *digest,
-    Span<const uint8_t> client_hello, size_t binders_len) {
-  // Truncate the ClientHello.
-  if (binders_len + 2 < binders_len || client_hello.size() < binders_len + 2) {
-    return false;
-  }
-  client_hello = client_hello.subspan(0, client_hello.size() - binders_len - 2);
-
-  ScopedEVP_MD_CTX ctx;
-  unsigned len;
-  if (!hs->transcript.CopyToHashContext(ctx.get(), digest) ||
-      !EVP_DigestUpdate(ctx.get(), client_hello.data(), client_hello.size()) ||
-      !EVP_DigestFinal_ex(ctx.get(), out, &len)) {
-    return false;
-  }
-
-  *out_len = len;
-  return true;
-}
-
-bool tls13_write_psk_binder(SSL_HANDSHAKE *hs, Span<uint8_t> msg) {
-  SSL *const ssl = hs->ssl;
+bool tls13_write_psk_binder(const SSL_HANDSHAKE *hs,
+                            Span<uint8_t> msg) {
+  const SSL *const ssl = hs->ssl;
   const EVP_MD *digest = ssl_session_get_digest(ssl->session.get());
-  size_t hash_len = EVP_MD_size(digest);
-
-  ScopedEVP_MD_CTX ctx;
-  uint8_t context[EVP_MAX_MD_SIZE];
-  size_t context_len;
+  const size_t hash_len = EVP_MD_size(digest);
+  // We only offer one PSK, so the binders are a u16 and u8 length
+  // prefix, followed by the binder. The caller is assumed to have constructed
+  // |msg| with placeholder binders.
+  const size_t binders_len = 3 + hash_len;
   uint8_t verify_data[EVP_MAX_MD_SIZE];
   size_t verify_data_len;
-  if (!hash_transcript_and_truncated_client_hello(
-          hs, context, &context_len, digest, msg,
-          1 /* length prefix */ + hash_len) ||
-      !tls13_psk_binder(
-          verify_data, &verify_data_len, ssl->session->ssl_version, digest,
-          MakeConstSpan(ssl->session->secret, ssl->session->secret_length),
-          MakeConstSpan(context, context_len)) ||
+  if (!tls13_psk_binder(verify_data, &verify_data_len, ssl->session.get(),
+                        hs->transcript, msg, binders_len) ||
       verify_data_len != hash_len) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
     return false;
@@ -473,20 +471,17 @@
   return true;
 }
 
-bool tls13_verify_psk_binder(SSL_HANDSHAKE *hs, SSL_SESSION *session,
-                             const SSLMessage &msg, CBS *binders) {
-  uint8_t context[EVP_MAX_MD_SIZE];
-  size_t context_len;
+bool tls13_verify_psk_binder(const SSL_HANDSHAKE *hs,
+                             const SSL_SESSION *session, const SSLMessage &msg,
+                             CBS *binders) {
   uint8_t verify_data[EVP_MAX_MD_SIZE];
   size_t verify_data_len;
   CBS binder;
-  if (!hash_transcript_and_truncated_client_hello(hs, context, &context_len,
-                                                  hs->transcript.Digest(),
-                                                  msg.raw, CBS_len(binders)) ||
-      !tls13_psk_binder(verify_data, &verify_data_len, hs->ssl->version,
-                        hs->transcript.Digest(),
-                        MakeConstSpan(session->secret, session->secret_length),
-                        MakeConstSpan(context, context_len)) ||
+  // The binders are computed over |msg| with |binders| and its u16 length
+  // prefix removed. The caller is assumed to have parsed |msg|, extracted
+  // |binders|, and verified the PSK extension is last.
+  if (!tls13_psk_binder(verify_data, &verify_data_len, session, hs->transcript,
+                        msg.raw, 2 + CBS_len(binders)) ||
       // We only consider the first PSK, so compare against the first binder.
       !CBS_get_u8_length_prefixed(binders, &binder)) {
     OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);