Add a test for certificate types parsing.

Change-Id: Icddd39ae183f981f78a65427a4dda34449ca389a
Reviewed-on: https://boringssl-review.googlesource.com/1111
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc
index d0117da..72be47e 100644
--- a/ssl/test/bssl_shim.cc
+++ b/ssl/test/bssl_shim.cc
@@ -28,35 +28,37 @@
 int select_certificate_callback(const struct ssl_early_callback_ctx *ctx) {
   early_callback_called = 1;
 
-  if (expected_server_name) {
-    const unsigned char *extension_data;
-    size_t extension_len;
-    CBS extension, server_name_list, host_name;
-    uint8_t name_type;
+  if (!expected_server_name) {
+    return 1;
+  }
 
-    if (!SSL_early_callback_ctx_extension_get(ctx, TLSEXT_TYPE_server_name,
-                                              &extension_data,
-                                              &extension_len)) {
-      fprintf(stderr, "Could not find server_name extension.");
-      return -1;
-    }
+  const uint8_t *extension_data;
+  size_t extension_len;
+  CBS extension, server_name_list, host_name;
+  uint8_t name_type;
 
-    CBS_init(&extension, extension_data, extension_len);
-    if (!CBS_get_u16_length_prefixed(&extension, &server_name_list) ||
-        CBS_len(&extension) != 0 ||
-        !CBS_get_u8(&server_name_list, &name_type) ||
-        name_type != TLSEXT_NAMETYPE_host_name ||
-        !CBS_get_u16_length_prefixed(&server_name_list, &host_name) ||
-        CBS_len(&server_name_list) != 0) {
-      fprintf(stderr, "Could not decode server_name extension.");
-      return -1;
-    }
+  if (!SSL_early_callback_ctx_extension_get(ctx, TLSEXT_TYPE_server_name,
+                                            &extension_data,
+                                            &extension_len)) {
+    fprintf(stderr, "Could not find server_name extension.\n");
+    return -1;
+  }
 
-    if (CBS_len(&host_name) != strlen(expected_server_name) ||
-        memcmp(expected_server_name,
-               CBS_data(&host_name), CBS_len(&host_name)) != 0) {
-      fprintf(stderr, "Server name mismatch.");
-    }
+  CBS_init(&extension, extension_data, extension_len);
+  if (!CBS_get_u16_length_prefixed(&extension, &server_name_list) ||
+      CBS_len(&extension) != 0 ||
+      !CBS_get_u8(&server_name_list, &name_type) ||
+      name_type != TLSEXT_NAMETYPE_host_name ||
+      !CBS_get_u16_length_prefixed(&server_name_list, &host_name) ||
+      CBS_len(&server_name_list) != 0) {
+    fprintf(stderr, "Could not decode server_name extension.\n");
+    return -1;
+  }
+
+  if (CBS_len(&host_name) != strlen(expected_server_name) ||
+      memcmp(expected_server_name,
+             CBS_data(&host_name), CBS_len(&host_name)) != 0) {
+    fprintf(stderr, "Server name mismatch.\n");
   }
 
   return 1;
@@ -117,6 +119,7 @@
 
 int main(int argc, char **argv) {
   int i, is_server, ret;
+  const char *expected_certificate_types = NULL;
 
   if (argc < 2) {
     fprintf(stderr, "Usage: %s (client|server) [flags...]\n", argv[0]);
@@ -170,6 +173,14 @@
         return 1;
       }
       expected_server_name = argv[i];
+    } else if (strcmp(argv[i], "-expect-certificate-types") == 0) {
+      i++;
+      if (i >= argc) {
+        fprintf(stderr, "Missing parameter\n");
+        return 1;
+      }
+      // Conveniently, 00 is not a certificate type.
+      expected_certificate_types = argv[i];
     } else {
       fprintf(stderr, "Unknown argument: %s\n", argv[i]);
       return 1;
@@ -202,6 +213,19 @@
     }
   }
 
+  if (expected_certificate_types) {
+    uint8_t *certificate_types;
+    int num_certificate_types =
+      SSL_get0_certificate_types(ssl, &certificate_types);
+    if (num_certificate_types != (int)strlen(expected_certificate_types) ||
+        memcmp(certificate_types,
+               expected_certificate_types,
+               num_certificate_types) != 0) {
+      fprintf(stderr, "certificate types mismatch\n");
+      return 2;
+    }
+  }
+
   for (;;) {
     uint8_t buf[512];
     int n = SSL_read(ssl, buf, sizeof(buf));
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go
index df7cacf..dca3e9d 100644
--- a/ssl/test/runner/common.go
+++ b/ssl/test/runner/common.go
@@ -105,15 +105,15 @@
 
 // Certificate types (for certificateRequestMsg)
 const (
-	certTypeRSASign    = 1 // A certificate containing an RSA key
-	certTypeDSSSign    = 2 // A certificate containing a DSA key
-	certTypeRSAFixedDH = 3 // A certificate containing a static DH key
-	certTypeDSSFixedDH = 4 // A certificate containing a static DH key
+	CertTypeRSASign    = 1 // A certificate containing an RSA key
+	CertTypeDSSSign    = 2 // A certificate containing a DSA key
+	CertTypeRSAFixedDH = 3 // A certificate containing a static DH key
+	CertTypeDSSFixedDH = 4 // A certificate containing a static DH key
 
 	// See RFC4492 sections 3 and 5.5.
-	certTypeECDSASign      = 64 // A certificate containing an ECDSA-capable public key, signed with ECDSA.
-	certTypeRSAFixedECDH   = 65 // A certificate containing an ECDH-capable public key, signed with RSA.
-	certTypeECDSAFixedECDH = 66 // A certificate containing an ECDH-capable public key, signed with ECDSA.
+	CertTypeECDSASign      = 64 // A certificate containing an ECDSA-capable public key, signed with ECDSA.
+	CertTypeRSAFixedECDH   = 65 // A certificate containing an ECDH-capable public key, signed with RSA.
+	CertTypeECDSAFixedECDH = 66 // A certificate containing an ECDH-capable public key, signed with ECDSA.
 
 	// Rest of these are reserved by the TLS spec
 )
@@ -251,6 +251,10 @@
 	// by the policy in ClientAuth.
 	ClientCAs *x509.CertPool
 
+	// ClientCertificateTypes defines the set of allowed client certificate
+	// types. The default is CertTypeRSASign and CertTypeECDSASign.
+	ClientCertificateTypes []byte
+
 	// InsecureSkipVerify controls whether a client verifies the
 	// server's certificate chain and host name.
 	// If InsecureSkipVerify is true, TLS accepts any certificate
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index 220e489..890a8a0 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -325,9 +325,9 @@
 		var rsaAvail, ecdsaAvail bool
 		for _, certType := range certReq.certificateTypes {
 			switch certType {
-			case certTypeRSASign:
+			case CertTypeRSASign:
 				rsaAvail = true
-			case certTypeECDSASign:
+			case CertTypeECDSASign:
 				ecdsaAvail = true
 			}
 		}
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index 328c15f..a32c078 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -337,10 +337,14 @@
 
 	if config.ClientAuth >= RequestClientCert {
 		// Request a client certificate
-		certReq := new(certificateRequestMsg)
-		certReq.certificateTypes = []byte{
-			byte(certTypeRSASign),
-			byte(certTypeECDSASign),
+		certReq := &certificateRequestMsg{
+			certificateTypes: config.ClientCertificateTypes,
+		}
+		if certReq.certificateTypes == nil {
+			certReq.certificateTypes = []byte{
+				byte(CertTypeRSASign),
+				byte(CertTypeECDSASign),
+			}
 		}
 		if c.vers >= VersionTLS12 {
 			certReq.hasSignatureAndHash = true
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index 96b52fa..7b1462a 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -165,6 +165,22 @@
 		shouldFail:         true,
 		expectedLocalError: "remote error: error decoding message",
 	},
+	{
+		name: "ClientCertificateTypes",
+		config: Config{
+			ClientAuth: RequestClientCert,
+			ClientCertificateTypes: []byte{
+				CertTypeDSSSign,
+				CertTypeRSASign,
+				CertTypeECDSASign,
+			},
+		},
+		flags: []string{"-expect-certificate-types", string([]byte{
+			CertTypeDSSSign,
+			CertTypeRSASign,
+			CertTypeECDSASign,
+		})},
+	},
 }
 
 func doExchange(tlsConn *Conn, messageLen int) error {