runner: Move writeHash to the finishedHash struct.
This avoids duplicating some code in client and server. It should also
clean up some ECH test code, which needs to juggle a pair of transcripts
for a brief window.
Change-Id: I4db11119e34b56453f01b5890060b8d4129a25b9
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/46564
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 510cbce..2701fd0 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -754,7 +754,7 @@
session: session,
}
- hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1)
+ hs.finishedHash.WriteHandshake(helloBytes, hs.c.sendHandshakeSeq-1)
if haveHelloRetryRequest {
err = hs.finishedHash.UpdateForHelloRetryRequest()
if err != nil {
@@ -1948,7 +1948,7 @@
c.clientProtocolFallback = fallback
nextProtoBytes := nextProto.marshal()
- hs.writeHash(nextProtoBytes, seqno)
+ hs.finishedHash.WriteHandshake(nextProtoBytes, seqno)
seqno++
postCCSMsgs = append(postCCSMsgs, nextProtoBytes)
}
@@ -1962,7 +1962,7 @@
if err != nil {
return err
}
- hs.writeHash(channelIDMsgBytes, seqno)
+ hs.finishedHash.WriteHandshake(channelIDMsgBytes, seqno)
seqno++
postCCSMsgs = append(postCCSMsgs, channelIDMsgBytes)
}
@@ -1979,7 +1979,7 @@
}
c.clientVerify = append(c.clientVerify[:0], finished.verifyData...)
hs.finishedBytes = finished.marshal()
- hs.writeHash(hs.finishedBytes, seqno)
+ hs.finishedHash.WriteHandshake(hs.finishedBytes, seqno)
if c.config.Bugs.PartialClientFinishedWithClientHello {
// The first byte has already been written.
postCCSMsgs = append(postCCSMsgs, hs.finishedBytes[1:])
@@ -2055,28 +2055,12 @@
func (hs *clientHandshakeState) writeClientHash(msg []byte) {
// writeClientHash is called before writeRecord.
- hs.writeHash(msg, hs.c.sendHandshakeSeq)
+ hs.finishedHash.WriteHandshake(msg, hs.c.sendHandshakeSeq)
}
func (hs *clientHandshakeState) writeServerHash(msg []byte) {
// writeServerHash is called after readHandshake.
- hs.writeHash(msg, hs.c.recvHandshakeSeq-1)
-}
-
-func (hs *clientHandshakeState) writeHash(msg []byte, seqno uint16) {
- if hs.c.isDTLS {
- // This is somewhat hacky. DTLS hashes a slightly different format.
- // First, the TLS header.
- hs.finishedHash.Write(msg[:4])
- // Then the sequence number and reassembled fragment offset (always 0).
- hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
- // Then the reassembled fragment (always equal to the message length).
- hs.finishedHash.Write(msg[1:4])
- // And then the message body.
- hs.finishedHash.Write(msg[4:])
- } else {
- hs.finishedHash.Write(msg)
- }
+ hs.finishedHash.WriteHandshake(msg, hs.c.recvHandshakeSeq-1)
}
// selectClientCertificate selects a certificate for use with the given
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 907ea6f..33ecb13 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -2186,28 +2186,12 @@
func (hs *serverHandshakeState) writeServerHash(msg []byte) {
// writeServerHash is called before writeRecord.
- hs.writeHash(msg, hs.c.sendHandshakeSeq)
+ hs.finishedHash.WriteHandshake(msg, hs.c.sendHandshakeSeq)
}
func (hs *serverHandshakeState) writeClientHash(msg []byte) {
// writeClientHash is called after readHandshake.
- hs.writeHash(msg, hs.c.recvHandshakeSeq-1)
-}
-
-func (hs *serverHandshakeState) writeHash(msg []byte, seqno uint16) {
- if hs.c.isDTLS {
- // This is somewhat hacky. DTLS hashes a slightly different format.
- // First, the TLS header.
- hs.finishedHash.Write(msg[:4])
- // Then the sequence number and reassembled fragment offset (always 0).
- hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
- // Then the reassembled fragment (always equal to the message length).
- hs.finishedHash.Write(msg[1:4])
- // And then the message body.
- hs.finishedHash.Write(msg[4:])
- } else {
- hs.finishedHash.Write(msg)
- }
+ hs.finishedHash.WriteHandshake(msg, hs.c.recvHandshakeSeq-1)
}
// tryCipherSuite returns a cipherSuite with the given id if that cipher suite
diff --git a/ssl/test/runner/prf.go b/ssl/test/runner/prf.go
index b27f038..5f2bf49 100644
--- a/ssl/test/runner/prf.go
+++ b/ssl/test/runner/prf.go
@@ -215,6 +215,7 @@
ret.buffer = []byte{}
ret.version = version
ret.wireVersion = wireVersion
+ ret.isDTLS = isDTLS
return ret
}
@@ -236,6 +237,7 @@
version uint16
wireVersion uint16
+ isDTLS bool
prf func(result, secret, label, seed []byte)
// secret, in TLS 1.3, is the running input secret.
@@ -272,6 +274,25 @@
return len(msg), nil
}
+// WriteHandshake appends |msg| to the hash, which must be a serialized
+// handshake message with a TLS header. In DTLS, the header is rewritten to a
+// DTLS header with |seqno| as the sequence number.
+func (h *finishedHash) WriteHandshake(msg []byte, seqno uint16) {
+ if h.isDTLS {
+ // This is somewhat hacky. DTLS hashes a slightly different format.
+ // First, the TLS header.
+ h.Write(msg[:4])
+ // Then the sequence number and reassembled fragment offset (always 0).
+ h.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0})
+ // Then the reassembled fragment (always equal to the message length).
+ h.Write(msg[1:4])
+ // And then the message body.
+ h.Write(msg[4:])
+ } else {
+ h.Write(msg)
+ }
+}
+
func (h finishedHash) Sum() []byte {
if h.version >= VersionTLS12 {
return h.client.Sum(nil)