runner: Construct finishedHash earlier.

We currently construct finishedHash fairly late, after we've resolved
HelloRetryRequest. As a result, we need to defer some of the transcript
operations across a large chunk of code.

This is a remnant of earlier iterations of TLS 1.3, when
HelloRetryRequest didn't tell us the cipher suite yet. Now the cipher
suite is known earlier and we can construct the finishedHash object
immediately. In doing so, move HRR handling inside doTLS13Handshake().

This keeps more of TLS 1.3 bits together and allows us to maintain the
HRR bits of the handshake closer to the rest of HRR processing. This
will be useful for ECH which complicates this part of the process with
an inner and outer ClientHello. Finally, this adds a missing check that
the HRR and SH cipher suites match.

Change-Id: Iec149eb5c648973325b190f8a0622c9196bf3a29
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/46630
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index d368701..962737a 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -214,7 +214,7 @@
 		hello.compressedCertAlgs = []uint16{1, 1}
 	} else if len(c.config.CertCompressionAlgs) > 0 {
 		hello.compressedCertAlgs = make([]uint16, 0, len(c.config.CertCompressionAlgs))
-		for id, _ := range c.config.CertCompressionAlgs {
+		for id := range c.config.CertCompressionAlgs {
 			hello.compressedCertAlgs = append(hello.compressedCertAlgs, uint16(id))
 		}
 	}
@@ -223,7 +223,7 @@
 		hello.secureRenegotiation = nil
 	}
 
-	for protocol, _ := range c.config.ApplicationSettings {
+	for protocol := range c.config.ApplicationSettings {
 		hello.alpsProtocols = append(hello.alpsProtocols, protocol)
 	}
 
@@ -468,7 +468,6 @@
 		}
 	}
 
-	var helloBytes []byte
 	if c.config.Bugs.SendV2ClientHello {
 		hello.isV2ClientHello = true
 
@@ -493,10 +492,9 @@
 			}
 		}
 
-		helloBytes = hello.marshal()
-		c.writeV2Record(helloBytes)
+		c.writeV2Record(hello.marshal())
 	} else {
-		helloBytes = hello.marshal()
+		helloBytes := hello.marshal()
 		var appendToHello byte
 		if c.config.Bugs.PartialClientFinishedWithClientHello {
 			appendToHello = typeFinished
@@ -529,7 +527,7 @@
 	if sendEarlyData {
 		finishedHash := newFinishedHash(session.wireVersion, c.isDTLS, session.cipherSuite)
 		finishedHash.addEntropy(session.secret)
-		finishedHash.Write(helloBytes)
+		finishedHash.Write(hello.marshal())
 
 		if !c.config.Bugs.SkipChangeCipherSpec {
 			c.wireVersion = session.wireVersion
@@ -567,8 +565,7 @@
 
 			hello.raw = nil
 			hello.cookie = helloVerifyRequest.cookie
-			helloBytes = hello.marshal()
-			c.writeRecord(recordTypeHandshake, helloBytes)
+			c.writeRecord(recordTypeHandshake, hello.marshal())
 			c.flushHandshake()
 
 			if err := c.simulatePacketLoss(nil); err != nil {
@@ -581,12 +578,16 @@
 		}
 	}
 
-	var serverWireVersion uint16
+	// The first message is either ServerHello or HelloRetryRequest, either of
+	// which determines the version and cipher suite.
+	var serverWireVersion, suiteID uint16
 	switch m := msg.(type) {
 	case *helloRetryRequestMsg:
 		serverWireVersion = m.vers
+		suiteID = m.cipherSuite
 	case *serverHelloMsg:
 		serverWireVersion = m.vers
+		suiteID = m.cipherSuite
 	default:
 		c.sendAlert(alertUnexpectedMessage)
 		return fmt.Errorf("tls: received unexpected message of type %T when waiting for HelloRetryRequest or ServerHello", msg)
@@ -608,185 +609,33 @@
 		return errors.New("tls: server selected SSL 3.0")
 	}
 
-	if c.vers >= VersionTLS13 {
-		// The first server message must be followed by a ChangeCipherSpec.
-		c.expectTLS13ChangeCipherSpec = true
-	}
-
-	helloRetryRequest, haveHelloRetryRequest := msg.(*helloRetryRequestMsg)
-	var secondHelloBytes []byte
-	if haveHelloRetryRequest {
-		if c.config.Bugs.FailIfHelloRetryRequested {
-			return errors.New("tls: unexpected HelloRetryRequest")
-		}
-		// Explicitly read the ChangeCipherSpec now; it should
-		// be attached to the first flight, not the second flight.
-		if err := c.readTLS13ChangeCipherSpec(); err != nil {
-			return err
-		}
-
-		c.out.resetCipher()
-		if len(helloRetryRequest.cookie) > 0 {
-			hello.tls13Cookie = helloRetryRequest.cookie
-		}
-
-		if c.config.Bugs.MisinterpretHelloRetryRequestCurve != 0 {
-			helloRetryRequest.hasSelectedGroup = true
-			helloRetryRequest.selectedGroup = c.config.Bugs.MisinterpretHelloRetryRequestCurve
-		}
-		if helloRetryRequest.hasSelectedGroup {
-			var hrrCurveFound bool
-			group := helloRetryRequest.selectedGroup
-			for _, curveID := range hello.supportedCurves {
-				if group == curveID {
-					hrrCurveFound = true
-					break
-				}
-			}
-			if !hrrCurveFound || keyShares[group] != nil {
-				c.sendAlert(alertHandshakeFailure)
-				return errors.New("tls: received invalid HelloRetryRequest")
-			}
-			curve, ok := curveForCurveID(group, c.config)
-			if !ok {
-				return errors.New("tls: Unable to get curve requested in HelloRetryRequest")
-			}
-			publicKey, err := curve.offer(c.config.rand())
-			if err != nil {
-				return err
-			}
-			keyShares[group] = curve
-			hello.keyShares = []keyShareEntry{{
-				group:       group,
-				keyExchange: publicKey,
-			}}
-		}
-
-		if c.config.Bugs.SecondClientHelloMissingKeyShare {
-			hello.hasKeyShares = false
-		}
-
-		hello.hasEarlyData = c.config.Bugs.SendEarlyDataOnSecondClientHello
-		// The first ClientHello may have skipped this due to OnlyCorruptSecondPSKBinder.
-		if c.config.Bugs.PSKBinderFirst && c.config.Bugs.OnlyCorruptSecondPSKBinder {
-			hello.prefixExtensions = append(hello.prefixExtensions, extensionPreSharedKey)
-		}
-		if c.config.Bugs.OmitPSKsOnSecondClientHello {
-			hello.pskIdentities = nil
-			hello.pskBinders = nil
-		}
-		hello.raw = nil
-
-		if len(hello.pskIdentities) > 0 {
-			generatePSKBinders(c.wireVersion, hello, session, helloBytes, helloRetryRequest.marshal(), c.config)
-		}
-		secondHelloBytes = hello.marshal()
-		secondHelloBytesToWrite := secondHelloBytes
-
-		if c.config.Bugs.PartialSecondClientHelloAfterFirst {
-			// The first byte has already been sent.
-			secondHelloBytesToWrite = secondHelloBytesToWrite[1:]
-		}
-
-		if c.config.Bugs.InterleaveEarlyData {
-			c.sendFakeEarlyData(4)
-			c.writeRecord(recordTypeHandshake, secondHelloBytesToWrite[:16])
-			c.sendFakeEarlyData(4)
-			c.writeRecord(recordTypeHandshake, secondHelloBytesToWrite[16:])
-		} else if c.config.Bugs.PartialClientFinishedWithSecondClientHello {
-			toWrite := make([]byte, len(secondHelloBytesToWrite)+1)
-			copy(toWrite, secondHelloBytesToWrite)
-			toWrite[len(secondHelloBytesToWrite)] = typeFinished
-			c.writeRecord(recordTypeHandshake, toWrite)
-		} else {
-			c.writeRecord(recordTypeHandshake, secondHelloBytesToWrite)
-		}
-		c.flushHandshake()
-
-		if c.config.Bugs.SendEarlyDataOnSecondClientHello {
-			c.sendFakeEarlyData(4)
-		}
-
-		msg, err = c.readHandshake()
-		if err != nil {
-			return err
-		}
-	}
-
-	serverHello, ok := msg.(*serverHelloMsg)
-	if !ok {
-		c.sendAlert(alertUnexpectedMessage)
-		return unexpectedMessageError(serverHello, msg)
-	}
-
-	if serverWireVersion != serverHello.vers {
-		c.sendAlert(alertIllegalParameter)
-		return fmt.Errorf("tls: server sent non-matching version %x vs %x", serverWireVersion, serverHello.vers)
-	}
-
-	_, supportsTLS13 := c.config.isSupportedVersion(VersionTLS13, false)
-	// Check for downgrade signals in the server random, per RFC 8446, section 4.1.3.
-	gotDowngrade := serverHello.random[len(serverHello.random)-8:]
-	if supportsTLS13 && !c.config.Bugs.IgnoreTLS13DowngradeRandom {
-		if c.vers <= VersionTLS12 && c.config.maxVersion(c.isDTLS) >= VersionTLS13 {
-			if bytes.Equal(gotDowngrade, downgradeTLS13) {
-				c.sendAlert(alertProtocolVersion)
-				return errors.New("tls: downgrade from TLS 1.3 detected")
-			}
-		}
-		if c.vers <= VersionTLS11 && c.config.maxVersion(c.isDTLS) >= VersionTLS12 {
-			if bytes.Equal(gotDowngrade, downgradeTLS12) {
-				c.sendAlert(alertProtocolVersion)
-				return errors.New("tls: downgrade from TLS 1.2 detected")
-			}
-		}
-	}
-
-	if bytes.Equal(gotDowngrade, downgradeJDK11) != c.config.Bugs.ExpectJDK11DowngradeRandom {
-		c.sendAlert(alertProtocolVersion)
-		if c.config.Bugs.ExpectJDK11DowngradeRandom {
-			return errors.New("tls: server did not send a JDK 11 downgrade signal")
-		}
-		return errors.New("tls: server sent an unexpected JDK 11 downgrade signal")
-	}
-
-	suite := mutualCipherSuite(hello.cipherSuites, serverHello.cipherSuite)
+	suite := mutualCipherSuite(hello.cipherSuites, suiteID)
 	if suite == nil {
 		c.sendAlert(alertHandshakeFailure)
 		return fmt.Errorf("tls: server selected an unsupported cipher suite")
 	}
 
-	if haveHelloRetryRequest && helloRetryRequest.hasSelectedGroup && helloRetryRequest.selectedGroup != serverHello.keyShare.group {
-		c.sendAlert(alertHandshakeFailure)
-		return errors.New("tls: ServerHello parameters did not match HelloRetryRequest")
-	}
-
-	if c.config.Bugs.ExpectOmitExtensions && !serverHello.omitExtensions {
-		return errors.New("tls: ServerHello did not omit extensions")
-	}
-
 	hs := &clientHandshakeState{
 		c:            c,
-		serverHello:  serverHello,
 		hello:        hello,
 		suite:        suite,
 		finishedHash: newFinishedHash(c.wireVersion, c.isDTLS, suite),
 		keyShares:    keyShares,
 		session:      session,
 	}
-
-	hs.finishedHash.WriteHandshake(helloBytes, hs.c.sendHandshakeSeq-1)
-	if haveHelloRetryRequest {
-		hs.finishedHash.UpdateForHelloRetryRequest()
-		hs.writeServerHash(helloRetryRequest.marshal())
-		hs.writeClientHash(secondHelloBytes)
-	}
+	hs.finishedHash.WriteHandshake(hello.marshal(), hs.c.sendHandshakeSeq-1)
 
 	if c.vers >= VersionTLS13 {
-		if err := hs.doTLS13Handshake(); err != nil {
+		if err := hs.doTLS13Handshake(msg); err != nil {
 			return err
 		}
 	} else {
+		hs.serverHello, ok = msg.(*serverHelloMsg)
+		if !ok {
+			c.sendAlert(alertUnexpectedMessage)
+			return unexpectedMessageError(hs.serverHello, msg)
+		}
+
 		hs.writeServerHash(hs.serverHello.marshal())
 		if c.config.Bugs.EarlyChangeCipherSpec > 0 {
 			hs.establishKeys()
@@ -798,7 +647,7 @@
 			return errors.New("tls: server selected unsupported compression format")
 		}
 
-		err = hs.processServerExtensions(&serverHello.extensions)
+		err = hs.processServerExtensions(&hs.serverHello.extensions)
 		if err != nil {
 			return err
 		}
@@ -873,9 +722,143 @@
 	return nil
 }
 
-func (hs *clientHandshakeState) doTLS13Handshake() error {
+func (hs *clientHandshakeState) doTLS13Handshake(msg interface{}) error {
 	c := hs.c
 
+	// Once the PRF hash is known, TLS 1.3 does not require a handshake buffer.
+	hs.finishedHash.discardHandshakeBuffer()
+
+	// The first server message must be followed by a ChangeCipherSpec.
+	c.expectTLS13ChangeCipherSpec = true
+
+	// The first message may be a ServerHello or HelloRetryRequest.
+	helloRetryRequest, haveHelloRetryRequest := msg.(*helloRetryRequestMsg)
+	if haveHelloRetryRequest {
+		hs.finishedHash.UpdateForHelloRetryRequest()
+		hs.writeServerHash(helloRetryRequest.marshal())
+
+		if c.config.Bugs.FailIfHelloRetryRequested {
+			return errors.New("tls: unexpected HelloRetryRequest")
+		}
+		// Explicitly read the ChangeCipherSpec now; it should
+		// be attached to the first flight, not the second flight.
+		if err := c.readTLS13ChangeCipherSpec(); err != nil {
+			return err
+		}
+
+		// Reset the encryption state, in case we sent 0-RTT data.
+		c.out.resetCipher()
+
+		firstHelloBytes := hs.hello.marshal()
+		if len(helloRetryRequest.cookie) > 0 {
+			hs.hello.tls13Cookie = helloRetryRequest.cookie
+		}
+
+		if c.config.Bugs.MisinterpretHelloRetryRequestCurve != 0 {
+			helloRetryRequest.hasSelectedGroup = true
+			helloRetryRequest.selectedGroup = c.config.Bugs.MisinterpretHelloRetryRequestCurve
+		}
+		if helloRetryRequest.hasSelectedGroup {
+			var hrrCurveFound bool
+			group := helloRetryRequest.selectedGroup
+			for _, curveID := range hs.hello.supportedCurves {
+				if group == curveID {
+					hrrCurveFound = true
+					break
+				}
+			}
+			if !hrrCurveFound || hs.keyShares[group] != nil {
+				c.sendAlert(alertHandshakeFailure)
+				return errors.New("tls: received invalid HelloRetryRequest")
+			}
+			curve, ok := curveForCurveID(group, c.config)
+			if !ok {
+				return errors.New("tls: Unable to get curve requested in HelloRetryRequest")
+			}
+			publicKey, err := curve.offer(c.config.rand())
+			if err != nil {
+				return err
+			}
+			hs.keyShares[group] = curve
+			hs.hello.keyShares = []keyShareEntry{{
+				group:       group,
+				keyExchange: publicKey,
+			}}
+		}
+
+		if c.config.Bugs.SecondClientHelloMissingKeyShare {
+			hs.hello.hasKeyShares = false
+		}
+
+		hs.hello.hasEarlyData = c.config.Bugs.SendEarlyDataOnSecondClientHello
+		// The first ClientHello may have skipped this due to OnlyCorruptSecondPSKBinder.
+		if c.config.Bugs.PSKBinderFirst && c.config.Bugs.OnlyCorruptSecondPSKBinder {
+			hs.hello.prefixExtensions = append(hs.hello.prefixExtensions, extensionPreSharedKey)
+		}
+		if c.config.Bugs.OmitPSKsOnSecondClientHello {
+			hs.hello.pskIdentities = nil
+			hs.hello.pskBinders = nil
+		}
+		hs.hello.raw = nil
+
+		if len(hs.hello.pskIdentities) > 0 {
+			generatePSKBinders(c.wireVersion, hs.hello, hs.session, firstHelloBytes, helloRetryRequest.marshal(), c.config)
+		}
+		hs.writeClientHash(hs.hello.marshal())
+		toWrite := hs.hello.marshal()
+
+		if c.config.Bugs.PartialSecondClientHelloAfterFirst {
+			// The first byte has already been sent.
+			toWrite = toWrite[1:]
+		}
+
+		if c.config.Bugs.InterleaveEarlyData {
+			c.sendFakeEarlyData(4)
+			c.writeRecord(recordTypeHandshake, toWrite[:16])
+			c.sendFakeEarlyData(4)
+			c.writeRecord(recordTypeHandshake, toWrite[16:])
+		} else if c.config.Bugs.PartialClientFinishedWithSecondClientHello {
+			toWrite = append(make([]byte, 0, len(toWrite)+1), toWrite...)
+			toWrite = append(toWrite, typeFinished)
+			c.writeRecord(recordTypeHandshake, toWrite)
+		} else {
+			c.writeRecord(recordTypeHandshake, toWrite)
+		}
+		c.flushHandshake()
+
+		if c.config.Bugs.SendEarlyDataOnSecondClientHello {
+			c.sendFakeEarlyData(4)
+		}
+
+		var err error
+		msg, err = c.readHandshake()
+		if err != nil {
+			return err
+		}
+	}
+
+	var ok bool
+	hs.serverHello, ok = msg.(*serverHelloMsg)
+	if !ok {
+		c.sendAlert(alertUnexpectedMessage)
+		return unexpectedMessageError(hs.serverHello, msg)
+	}
+
+	if c.wireVersion != hs.serverHello.vers {
+		c.sendAlert(alertIllegalParameter)
+		return fmt.Errorf("tls: server sent non-matching version %x vs %x", c.wireVersion, hs.serverHello.vers)
+	}
+
+	if hs.suite.id != hs.serverHello.cipherSuite {
+		c.sendAlert(alertIllegalParameter)
+		return fmt.Errorf("tls: server sent non-matching cipher suite %04x vs %04x", hs.suite.id, hs.serverHello.cipherSuite)
+	}
+
+	if haveHelloRetryRequest && helloRetryRequest.hasSelectedGroup && helloRetryRequest.selectedGroup != hs.serverHello.keyShare.group {
+		c.sendAlert(alertHandshakeFailure)
+		return errors.New("tls: ServerHello parameters did not match HelloRetryRequest")
+	}
+
 	if !bytes.Equal(hs.hello.sessionID, hs.serverHello.sessionID) {
 		return errors.New("tls: session IDs did not match.")
 	}
@@ -883,9 +866,6 @@
 	zeroSecret := hs.finishedHash.zeroSecret()
 
 	// Resolve PSK and compute the early secret.
-	//
-	// TODO(davidben): This will need to be handled slightly earlier once
-	// 0-RTT is implemented.
 	if hs.serverHello.hasPSKIdentity {
 		// We send at most one PSK identity.
 		if hs.session == nil || hs.serverHello.pskIdentity != 0 {
@@ -940,10 +920,6 @@
 		return errors.New("tls: server indicated ECH acceptance")
 	}
 
-	// Once the PRF hash is known, TLS 1.3 does not require a handshake
-	// buffer.
-	hs.finishedHash.discardHandshakeBuffer()
-
 	hs.writeServerHash(hs.serverHello.marshal())
 
 	// Derive handshake traffic keys and switch read key to handshake
@@ -1815,6 +1791,35 @@
 func (hs *clientHandshakeState) processServerHello() (bool, error) {
 	c := hs.c
 
+	// Check for downgrade signals in the server random, per RFC 8446, section 4.1.3.
+	gotDowngrade := hs.serverHello.random[len(hs.serverHello.random)-8:]
+	if !c.config.Bugs.IgnoreTLS13DowngradeRandom {
+		if c.config.maxVersion(c.isDTLS) >= VersionTLS13 {
+			if bytes.Equal(gotDowngrade, downgradeTLS13) {
+				c.sendAlert(alertProtocolVersion)
+				return false, errors.New("tls: downgrade from TLS 1.3 detected")
+			}
+		}
+		if c.vers <= VersionTLS11 && c.config.maxVersion(c.isDTLS) >= VersionTLS12 {
+			if bytes.Equal(gotDowngrade, downgradeTLS12) {
+				c.sendAlert(alertProtocolVersion)
+				return false, errors.New("tls: downgrade from TLS 1.2 detected")
+			}
+		}
+	}
+
+	if bytes.Equal(gotDowngrade, downgradeJDK11) != c.config.Bugs.ExpectJDK11DowngradeRandom {
+		c.sendAlert(alertProtocolVersion)
+		if c.config.Bugs.ExpectJDK11DowngradeRandom {
+			return false, errors.New("tls: server did not send a JDK 11 downgrade signal")
+		}
+		return false, errors.New("tls: server sent an unexpected JDK 11 downgrade signal")
+	}
+
+	if c.config.Bugs.ExpectOmitExtensions && !hs.serverHello.omitExtensions {
+		return false, errors.New("tls: ServerHello did not omit extensions")
+	}
+
 	if hs.serverResumedSession() {
 		// For test purposes, assert that the server never accepts the
 		// resumption offer on renegotiation.