runner: Move writeHash to the finishedHash struct.

This avoids duplicating some code in client and server. It should also
clean up some ECH test code, which needs to juggle a pair of transcripts
for a brief window.

Change-Id: I4db11119e34b56453f01b5890060b8d4129a25b9
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/46564
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 510cbce..2701fd0 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -754,7 +754,7 @@
 		session:      session,
 	}
 
-	hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
+	hs.finishedHash.WriteHandshake(helloBytes, hs.c.sendHandshakeSeq-1)
 	if haveHelloRetryRequest {
 		err = hs.finishedHash.UpdateForHelloRetryRequest()
 		if err != nil {
@@ -1948,7 +1948,7 @@
 		c.clientProtocolFallback = fallback
 
 		nextProtoBytes := nextProto.marshal()
-		hs.writeHash(nextProtoBytes, seqno)
+		hs.finishedHash.WriteHandshake(nextProtoBytes, seqno)
 		seqno++
 		postCCSMsgs = append(postCCSMsgs, nextProtoBytes)
 	}
@@ -1962,7 +1962,7 @@
 		if err != nil {
 			return err
 		}
-		hs.writeHash(channelIDMsgBytes, seqno)
+		hs.finishedHash.WriteHandshake(channelIDMsgBytes, seqno)
 		seqno++
 		postCCSMsgs = append(postCCSMsgs, channelIDMsgBytes)
 	}
@@ -1979,7 +1979,7 @@
 	}
 	c.clientVerify = append(c.clientVerify[:0], finished.verifyData...)
 	hs.finishedBytes = finished.marshal()
-	hs.writeHash(hs.finishedBytes, seqno)
+	hs.finishedHash.WriteHandshake(hs.finishedBytes, seqno)
 	if c.config.Bugs.PartialClientFinishedWithClientHello {
 		// The first byte has already been written.
 		postCCSMsgs = append(postCCSMsgs, hs.finishedBytes[1:])
@@ -2055,28 +2055,12 @@
 
 func (hs *clientHandshakeState) writeClientHash(msg []byte) {
 	// writeClientHash is called before writeRecord.
-	hs.writeHash(msg, hs.c.sendHandshakeSeq)
+	hs.finishedHash.WriteHandshake(msg, hs.c.sendHandshakeSeq)
 }
 
 func (hs *clientHandshakeState) writeServerHash(msg []byte) {
 	// writeServerHash is called after readHandshake.
-	hs.writeHash(msg, hs.c.recvHandshakeSeq-1)
-}
-
-func (hs *clientHandshakeState) writeHash(msg []byte, seqno uint16) {
-	if hs.c.isDTLS {
-		// This is somewhat hacky. DTLS hashes a slightly different format.
-		// First, the TLS header.
-		hs.finishedHash.Write(msg[:4])
-		// Then the sequence number and reassembled fragment offset (always 0).
-		hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
-		// Then the reassembled fragment (always equal to the message length).
-		hs.finishedHash.Write(msg[1:4])
-		// And then the message body.
-		hs.finishedHash.Write(msg[4:])
-	} else {
-		hs.finishedHash.Write(msg)
-	}
+	hs.finishedHash.WriteHandshake(msg, hs.c.recvHandshakeSeq-1)
 }
 
 // selectClientCertificate selects a certificate for use with the given
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 907ea6f..33ecb13 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -2186,28 +2186,12 @@
 
 func (hs *serverHandshakeState) writeServerHash(msg []byte) {
 	// writeServerHash is called before writeRecord.
-	hs.writeHash(msg, hs.c.sendHandshakeSeq)
+	hs.finishedHash.WriteHandshake(msg, hs.c.sendHandshakeSeq)
 }
 
 func (hs *serverHandshakeState) writeClientHash(msg []byte) {
 	// writeClientHash is called after readHandshake.
-	hs.writeHash(msg, hs.c.recvHandshakeSeq-1)
-}
-
-func (hs *serverHandshakeState) writeHash(msg []byte, seqno uint16) {
-	if hs.c.isDTLS {
-		// This is somewhat hacky. DTLS hashes a slightly different format.
-		// First, the TLS header.
-		hs.finishedHash.Write(msg[:4])
-		// Then the sequence number and reassembled fragment offset (always 0).
-		hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
-		// Then the reassembled fragment (always equal to the message length).
-		hs.finishedHash.Write(msg[1:4])
-		// And then the message body.
-		hs.finishedHash.Write(msg[4:])
-	} else {
-		hs.finishedHash.Write(msg)
-	}
+	hs.finishedHash.WriteHandshake(msg, hs.c.recvHandshakeSeq-1)
 }
 
 // tryCipherSuite returns a cipherSuite with the given id if that cipher suite
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index b27f038..5f2bf49 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -215,6 +215,7 @@
 	ret.buffer = []byte{}
 	ret.version = version
 	ret.wireVersion = wireVersion
+	ret.isDTLS = isDTLS
 	return ret
 }
 
@@ -236,6 +237,7 @@
 
 	version     uint16
 	wireVersion uint16
+	isDTLS      bool
 	prf         func(result, secret, label, seed []byte)
 
 	// secret, in TLS 1.3, is the running input secret.
@@ -272,6 +274,25 @@
 	return len(msg), nil
 }
 
+// WriteHandshake appends |msg| to the hash, which must be a serialized
+// handshake message with a TLS header. In DTLS, the header is rewritten to a
+// DTLS header with |seqno| as the sequence number.
+func (h *finishedHash) WriteHandshake(msg []byte, seqno uint16) {
+	if h.isDTLS {
+		// This is somewhat hacky. DTLS hashes a slightly different format.
+		// First, the TLS header.
+		h.Write(msg[:4])
+		// Then the sequence number and reassembled fragment offset (always 0).
+		h.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
+		// Then the reassembled fragment (always equal to the message length).
+		h.Write(msg[1:4])
+		// And then the message body.
+		h.Write(msg[4:])
+	} else {
+		h.Write(msg)
+	}
+}
+
 func (h finishedHash) Sum() []byte {
 	if h.version >= VersionTLS12 {
 		return h.client.Sum(nil)