Write a CBS-based RSA-PSS parameter parser

For now this is just used by the X.509 verification logic and supports
the three sets of parameters we accept. Later EVP_PKEY_RSA_PSS will use
the same parser.

Bug: 384818542
Change-Id: I4d9a83283cbf239268f37d814f00d0b6ec3b46f2
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/81655
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/digest/digest_extra.cc b/crypto/digest/digest_extra.cc
index 4312142..03df5f7 100644
--- a/crypto/digest/digest_extra.cc
+++ b/crypto/digest/digest_extra.cc
@@ -23,6 +23,7 @@
 #include <openssl/nid.h>
 #include <openssl/obj.h>
 #include <openssl/sha.h>
+#include <openssl/span.h>
 
 #include "../asn1/internal.h"
 #include "../fipsmodule/digest/internal.h"
@@ -102,42 +103,41 @@
     {{0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04}, 9, NID_sha224},
 };
 
-static const EVP_MD *cbs_to_md(const CBS *cbs) {
-  for (size_t i = 0; i < OPENSSL_ARRAY_SIZE(kMDOIDs); i++) {
-    if (CBS_len(cbs) == kMDOIDs[i].oid_len &&
-        OPENSSL_memcmp(CBS_data(cbs), kMDOIDs[i].oid, kMDOIDs[i].oid_len) ==
-            0) {
-      return EVP_get_digestbynid(kMDOIDs[i].nid);
+static int cbs_to_digest_nid(const CBS *cbs) {
+  for (const auto &md : kMDOIDs) {
+    if (bssl::Span<const uint8_t>(*cbs) ==
+        bssl::Span(md.oid).first(md.oid_len)) {
+      return md.nid;
     }
   }
-
-  return NULL;
+  return NID_undef;
 }
 
 const EVP_MD *EVP_get_digestbyobj(const ASN1_OBJECT *obj) {
-  // Handle objects with no corresponding OID. Note we don't use |OBJ_obj2nid|
-  // here to avoid pulling in the OID table.
-  if (obj->nid != NID_undef) {
-    return EVP_get_digestbynid(obj->nid);
+  int nid = obj->nid;
+  if (nid == NID_undef) {
+    // Handle objects with no saved NID. Note we don't use |OBJ_obj2nid| here to
+    // avoid pulling in the OID table.
+    CBS cbs;
+    CBS_init(&cbs, OBJ_get0_data(obj), OBJ_length(obj));
+    nid = cbs_to_digest_nid(&cbs);
   }
 
-  CBS cbs;
-  CBS_init(&cbs, OBJ_get0_data(obj), OBJ_length(obj));
-  return cbs_to_md(&cbs);
+  return nid == NID_undef ? nullptr : EVP_get_digestbynid(nid);
 }
 
-const EVP_MD *EVP_parse_digest_algorithm(CBS *cbs) {
+int EVP_parse_digest_algorithm_nid(CBS *cbs) {
   CBS algorithm, oid;
   if (!CBS_get_asn1(cbs, &algorithm, CBS_ASN1_SEQUENCE) ||
       !CBS_get_asn1(&algorithm, &oid, CBS_ASN1_OBJECT)) {
     OPENSSL_PUT_ERROR(DIGEST, DIGEST_R_DECODE_ERROR);
-    return NULL;
+    return NID_undef;
   }
 
-  const EVP_MD *ret = cbs_to_md(&oid);
-  if (ret == NULL) {
+  int ret = cbs_to_digest_nid(&oid);
+  if (ret == NID_undef) {
     OPENSSL_PUT_ERROR(DIGEST, DIGEST_R_UNKNOWN_HASH);
-    return NULL;
+    return NID_undef;
   }
 
   // The parameters, if present, must be NULL. Historically, whether the NULL
@@ -150,13 +150,21 @@
         CBS_len(&param) != 0 ||  //
         CBS_len(&algorithm) != 0) {
       OPENSSL_PUT_ERROR(DIGEST, DIGEST_R_DECODE_ERROR);
-      return NULL;
+      return NID_undef;
     }
   }
 
   return ret;
 }
 
+const EVP_MD *EVP_parse_digest_algorithm(CBS *cbs) {
+  int nid = EVP_parse_digest_algorithm_nid(cbs);
+  if (nid == NID_undef) {
+    return nullptr;
+  }
+  return EVP_get_digestbynid(nid);
+}
+
 static int marshal_digest_algorithm(CBB *cbb, const EVP_MD *md,
                                     bool with_null) {
   CBB algorithm, oid, null;
diff --git a/crypto/rsa/internal.h b/crypto/rsa/internal.h
index 21552bf..5fb4f17 100644
--- a/crypto/rsa/internal.h
+++ b/crypto/rsa/internal.h
@@ -28,6 +28,31 @@
                                       size_t param_len, const EVP_MD *md,
                                       const EVP_MD *mgf1md);
 
+enum rsa_pss_params_t {
+  // RSA-PSS using SHA-256, MGF1 with SHA-256, salt length 32.
+  rsa_pss_sha256,
+  // RSA-PSS using SHA-384, MGF1 with SHA-384, salt length 48.
+  rsa_pss_sha384,
+  // RSA-PSS using SHA-512, MGF1 with SHA-512, salt length 64.
+  rsa_pss_sha512,
+};
+
+// rsa_pss_params_get_md returns the hash function used with |params|. This also
+// specifies the MGF-1 hash and the salt length because we do not support other
+// configurations.
+const EVP_MD *rsa_pss_params_get_md(rsa_pss_params_t params);
+
+// rsa_marshal_pss_params marshals |params| as a DER-encoded RSASSA-PSS-params
+// (RFC 4055). It returns one on success and zero on error.
+int rsa_marshal_pss_params(CBB *cbb, rsa_pss_params_t params);
+
+// rsa_marshal_pss_params decodes a DER-encoded RSASSA-PSS-params
+// (RFC 4055). It returns one on success and zero on error. On success, it sets
+// |*out| to the result. If |allow_explicit_trailer| is non-zero, an explicit
+// encoding of the trailerField is allowed, although it is not valid DER.
+int rsa_parse_pss_params(CBS *cbs, rsa_pss_params_t *out,
+                         int allow_explicit_trailer);
+
 
 #if defined(__cplusplus)
 }  // extern C
diff --git a/crypto/rsa/rsa_asn1.cc b/crypto/rsa/rsa_asn1.cc
index a71cc3a..1a1fab3 100644
--- a/crypto/rsa/rsa_asn1.cc
+++ b/crypto/rsa/rsa_asn1.cc
@@ -20,12 +20,17 @@
 
 #include <openssl/bn.h>
 #include <openssl/bytestring.h>
+#include <openssl/digest.h>
 #include <openssl/err.h>
 #include <openssl/mem.h>
+#include <openssl/nid.h>
+#include <openssl/span.h>
+#include <openssl/x509.h>
 
-#include "../fipsmodule/rsa/internal.h"
 #include "../bytestring/internal.h"
+#include "../fipsmodule/rsa/internal.h"
 #include "../internal.h"
+#include "internal.h"
 
 
 static int parse_integer(CBS *cbs, BIGNUM **out) {
@@ -281,3 +286,146 @@
   OPENSSL_free(der);
   return ret;
 }
+
+static const uint8_t kPSSParamsSHA256[] = {
+    0x30, 0x34, 0xa0, 0x0f, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48,
+    0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05, 0x00, 0xa1, 0x1c, 0x30,
+    0x1a, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01,
+    0x08, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
+    0x04, 0x02, 0x01, 0x05, 0x00, 0xa2, 0x03, 0x02, 0x01, 0x20};
+
+static const uint8_t kPSSParamsSHA384[] = {
+    0x30, 0x34, 0xa0, 0x0f, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48,
+    0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05, 0x00, 0xa1, 0x1c, 0x30,
+    0x1a, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01,
+    0x08, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
+    0x04, 0x02, 0x02, 0x05, 0x00, 0xa2, 0x03, 0x02, 0x01, 0x30};
+
+static const uint8_t kPSSParamsSHA512[] = {
+    0x30, 0x34, 0xa0, 0x0f, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48,
+    0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05, 0x00, 0xa1, 0x1c, 0x30,
+    0x1a, 0x06, 0x09, 0x2a, 0x86, 0x48, 0x86, 0xf7, 0x0d, 0x01, 0x01,
+    0x08, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03,
+    0x04, 0x02, 0x03, 0x05, 0x00, 0xa2, 0x03, 0x02, 0x01, 0x40};
+
+const EVP_MD *rsa_pss_params_get_md(rsa_pss_params_t params) {
+  switch (params) {
+    case rsa_pss_sha256:
+      return EVP_sha256();
+    case rsa_pss_sha384:
+      return EVP_sha384();
+    case rsa_pss_sha512:
+      return EVP_sha512();
+  }
+  abort();
+}
+
+int rsa_marshal_pss_params(CBB *cbb, rsa_pss_params_t params) {
+  bssl::Span<const uint8_t> bytes;
+  switch (params) {
+    case rsa_pss_sha256:
+      bytes = kPSSParamsSHA256;
+      break;
+    case rsa_pss_sha384:
+      bytes = kPSSParamsSHA384;
+      break;
+    case rsa_pss_sha512:
+      bytes = kPSSParamsSHA512;
+      break;
+  }
+
+  return CBB_add_bytes(cbb, bytes.data(), bytes.size());
+}
+
+// 1.2.840.113549.1.1.8
+static const uint8_t kMGF1OID[] = {0x2a, 0x86, 0x48, 0x86, 0xf7,
+                                   0x0d, 0x01, 0x01, 0x08};
+
+int rsa_parse_pss_params(CBS *cbs, rsa_pss_params_t *out,
+                         int allow_explicit_trailer) {
+  // See RFC 4055, section 3.1.
+  //
+  // hashAlgorithm, maskGenAlgorithm, and saltLength all have DEFAULTs
+  // corresponding to SHA-1. We do not support SHA-1 with PSS, so we do not
+  // bother recognizing the omitted versions.
+  CBS params, hash_wrapper, mask_wrapper, mask_alg, mask_oid, salt_wrapper;
+  uint64_t salt_len;
+  if (!CBS_get_asn1(cbs, &params, CBS_ASN1_SEQUENCE) ||
+      !CBS_get_asn1(&params, &hash_wrapper,
+                    CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 0) ||
+      // |hash_wrapper| will be parsed below.
+      !CBS_get_asn1(&params, &mask_wrapper,
+                    CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 1) ||
+      !CBS_get_asn1(&mask_wrapper, &mask_alg, CBS_ASN1_SEQUENCE) ||
+      !CBS_get_asn1(&mask_alg, &mask_oid, CBS_ASN1_OBJECT) ||
+      // We only support MGF-1.
+      bssl::Span<const uint8_t>(mask_oid) != kMGF1OID ||
+      // The remainder of |mask_alg| will be parsed below.
+      CBS_len(&mask_wrapper) != 0 ||
+      !CBS_get_asn1(&params, &salt_wrapper,
+                    CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 2) ||
+      !CBS_get_asn1_uint64(&salt_wrapper, &salt_len) ||
+      CBS_len(&salt_wrapper) != 0) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+    return 0;
+  }
+
+  // The trailer field must be 1 (0xbc). This value is DEFAULT, so the structure
+  // is required to omit it in DER.
+  if (CBS_len(&params) != 0 && allow_explicit_trailer) {
+    CBS trailer_wrapper;
+    uint64_t trailer;
+    if (!CBS_get_asn1(&params, &trailer_wrapper,
+                      CBS_ASN1_CONSTRUCTED | CBS_ASN1_CONTEXT_SPECIFIC | 3) ||
+        !CBS_get_asn1_uint64(&trailer_wrapper, &trailer) ||  //
+        trailer != 1) {
+      OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+      return 0;
+    }
+  }
+  if (CBS_len(&params) != 0) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+    return 0;
+  }
+
+  int hash_nid = EVP_parse_digest_algorithm_nid(&hash_wrapper);
+  if (hash_nid == NID_undef || CBS_len(&hash_wrapper) != 0) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+    return 0;
+  }
+
+  // We only support combinations where the MGF-1 hash matches the overall hash.
+  int mgf1_hash_nid = EVP_parse_digest_algorithm_nid(&mask_alg);
+  if (mgf1_hash_nid != hash_nid || CBS_len(&mask_alg) != 0) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+    return 0;
+  }
+
+  // We only support salt lengths that match the hash length.
+  rsa_pss_params_t ret;
+  uint64_t hash_len;
+  switch (hash_nid) {
+    case NID_sha256:
+      ret = rsa_pss_sha256;
+      hash_len = 32;
+      break;
+    case NID_sha384:
+      ret = rsa_pss_sha384;
+      hash_len = 48;
+      break;
+    case NID_sha512:
+      ret = rsa_pss_sha512;
+      hash_len = 64;
+      break;
+    default:
+      OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+      return 0;
+  }
+  if (salt_len != hash_len) {
+    OPENSSL_PUT_ERROR(RSA, RSA_R_BAD_ENCODING);
+    return 0;
+  }
+
+  *out = ret;
+  return 1;
+}
diff --git a/crypto/x509/rsa_pss.cc b/crypto/x509/rsa_pss.cc
index 689f64d..3ad9172 100644
--- a/crypto/x509/rsa_pss.cc
+++ b/crypto/x509/rsa_pss.cc
@@ -24,6 +24,7 @@
 #include <openssl/evp.h>
 #include <openssl/obj.h>
 
+#include "../rsa/internal.h"
 #include "internal.h"
 
 
@@ -46,111 +47,18 @@
 IMPLEMENT_ASN1_FUNCTIONS_const(RSA_PSS_PARAMS)
 
 
-// Given an MGF1 Algorithm ID decode to an Algorithm Identifier
-static X509_ALGOR *rsa_mgf1_decode(const X509_ALGOR *alg) {
-  if (OBJ_obj2nid(alg->algorithm) != NID_mgf1 || alg->parameter == NULL ||
-      alg->parameter->type != V_ASN1_SEQUENCE) {
-    return NULL;
-  }
-
-  const uint8_t *p = alg->parameter->value.sequence->data;
-  int plen = alg->parameter->value.sequence->length;
-  return d2i_X509_ALGOR(NULL, &p, plen);
-}
-
-static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg) {
-  if (alg->parameter == NULL || alg->parameter->type != V_ASN1_SEQUENCE) {
-    return NULL;
-  }
-
-  const uint8_t *p = alg->parameter->value.sequence->data;
-  int plen = alg->parameter->value.sequence->length;
-  return d2i_RSA_PSS_PARAMS(NULL, &p, plen);
-}
-
-static int is_allowed_pss_md(const EVP_MD *md) {
-  int md_type = EVP_MD_type(md);
-  return md_type == NID_sha256 || md_type == NID_sha384 ||
-         md_type == NID_sha512;
-}
-
-// rsa_md_to_algor sets |*palg| to an |X509_ALGOR| describing the digest |md|,
-// which must be an allowed PSS digest.
-static int rsa_md_to_algor(X509_ALGOR **palg, const EVP_MD *md) {
-  // SHA-1 should be omitted (DEFAULT), but we do not allow SHA-1.
-  assert(is_allowed_pss_md(md));
-  *palg = X509_ALGOR_new();
-  if (*palg == NULL) {
+static int rsa_pss_decode(const X509_ALGOR *alg, rsa_pss_params_t *out) {
+  if (alg->parameter == nullptr || alg->parameter->type != V_ASN1_SEQUENCE) {
     return 0;
   }
-  if (!X509_ALGOR_set_md(*palg, md)) {
-    X509_ALGOR_free(*palg);
-    *palg = NULL;
-    return 0;
-  }
-  return 1;
-}
 
-// rsa_md_to_mgf1 sets |*palg| to an |X509_ALGOR| describing MGF-1 with the
-// digest |mgf1md|, which must be an allowed PSS digest.
-static int rsa_md_to_mgf1(X509_ALGOR **palg, const EVP_MD *mgf1md) {
-  // SHA-1 should be omitted (DEFAULT), but we do not allow SHA-1.
-  assert(is_allowed_pss_md(mgf1md));
-  X509_ALGOR *algtmp = NULL;
-  ASN1_STRING *stmp = NULL;
-  // need to embed algorithm ID inside another
-  if (!rsa_md_to_algor(&algtmp, mgf1md) ||
-      !ASN1_item_pack(algtmp, ASN1_ITEM_rptr(X509_ALGOR), &stmp)) {
-    goto err;
-  }
-  *palg = X509_ALGOR_new();
-  if (!*palg) {
-    goto err;
-  }
-  if (!X509_ALGOR_set0(*palg, OBJ_nid2obj(NID_mgf1), V_ASN1_SEQUENCE, stmp)) {
-    goto err;
-  }
-  stmp = NULL;
-
-err:
-  ASN1_STRING_free(stmp);
-  X509_ALGOR_free(algtmp);
-  if (*palg) {
-    return 1;
-  }
-
-  return 0;
-}
-
-static const EVP_MD *rsa_algor_to_md(const X509_ALGOR *alg) {
-  if (!alg) {
-    // If omitted, PSS defaults to SHA-1, which we do not allow.
-    OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-    return NULL;
-  }
-  const EVP_MD *md = EVP_get_digestbyobj(alg->algorithm);
-  if (md == NULL || !is_allowed_pss_md(md)) {
-    OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-    return NULL;
-  }
-  return md;
-}
-
-static const EVP_MD *rsa_mgf1_to_md(const X509_ALGOR *alg) {
-  if (!alg) {
-    // If omitted, PSS defaults to MGF-1 with SHA-1, which we do not allow.
-    OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-    return NULL;
-  }
-  // Check mask and lookup mask hash algorithm.
-  X509_ALGOR *maskHash = rsa_mgf1_decode(alg);
-  if (maskHash == NULL) {
-    OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-    return NULL;
-  }
-  const EVP_MD *ret = rsa_algor_to_md(maskHash);
-  X509_ALGOR_free(maskHash);
-  return ret;
+  // Although a syntax error in DER, we tolerate an explicitly-encoded trailer.
+  // See the certificates in cl/362617931.
+  CBS cbs;
+  CBS_init(&cbs, alg->parameter->value.sequence->data,
+           alg->parameter->value.sequence->length);
+  return rsa_parse_pss_params(&cbs, out, /*allow_explicit_trailer=*/true) &&
+         CBS_len(&cbs) == 0;
 }
 
 int x509_rsa_ctx_to_pss(EVP_MD_CTX *ctx, X509_ALGOR *algor) {
@@ -162,203 +70,111 @@
     return 0;
   }
 
-  if (sigmd != mgf1md || !is_allowed_pss_md(sigmd)) {
+  if (sigmd != mgf1md) {
     OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
     return 0;
   }
   int md_len = (int)EVP_MD_size(sigmd);
-  if (saltlen == RSA_PSS_SALTLEN_DIGEST) {
-    saltlen = md_len;
-  } else if (saltlen != md_len) {
+  if (saltlen != RSA_PSS_SALTLEN_DIGEST && saltlen != md_len) {
     OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
     return 0;
   }
 
-  int ret = 0;
-  ASN1_STRING *os = NULL;
-  RSA_PSS_PARAMS *pss = RSA_PSS_PARAMS_new();
-  if (!pss) {
-    goto err;
+  rsa_pss_params_t params;
+  switch (EVP_MD_type(sigmd)) {
+    case NID_sha256:
+      params = rsa_pss_sha256;
+      break;
+    case NID_sha384:
+      params = rsa_pss_sha384;
+      break;
+    case NID_sha512:
+      params = rsa_pss_sha512;
+      break;
+    default:
+      OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
+      return 0;
   }
 
-  // The DEFAULT value is 20, but this does not match any supported digest.
-  assert(saltlen != 20);
-  pss->saltLength = ASN1_INTEGER_new();
-  if (!pss->saltLength ||  //
-      !ASN1_INTEGER_set_int64(pss->saltLength, saltlen)) {
-    goto err;
+  // Encode |params| to an |ASN1_STRING|.
+  uint8_t buf[128];   // The largest param fits comfortably in 128 bytes.
+  CBB cbb;
+  CBB_init_fixed(&cbb, buf, sizeof(buf));
+  if (!rsa_marshal_pss_params(&cbb, params)) {
+    return 0;
   }
-
-  if (!rsa_md_to_algor(&pss->hashAlgorithm, sigmd) ||
-      !rsa_md_to_mgf1(&pss->maskGenAlgorithm, mgf1md)) {
-    goto err;
-  }
-
-  // Finally create string with pss parameter encoding.
-  if (!ASN1_item_pack(pss, ASN1_ITEM_rptr(RSA_PSS_PARAMS), &os)) {
-    goto err;
+  bssl::UniquePtr<ASN1_STRING> params_str(
+      ASN1_STRING_type_new(V_ASN1_SEQUENCE));
+  if (params_str == nullptr ||
+      !ASN1_STRING_set(params_str.get(), CBB_data(&cbb), CBB_len(&cbb))) {
+    return 0;
   }
 
   if (!X509_ALGOR_set0(algor, OBJ_nid2obj(NID_rsassaPss), V_ASN1_SEQUENCE,
-                       os)) {
-    goto err;
+                       params_str.get())) {
+    return 0;
   }
-  os = NULL;
-  ret = 1;
-
-err:
-  RSA_PSS_PARAMS_free(pss);
-  ASN1_STRING_free(os);
-  return ret;
+  params_str.release();  // |X509_ALGOR_set0| took ownership.
+  return 1;
 }
 
 int x509_rsa_pss_to_ctx(EVP_MD_CTX *ctx, const X509_ALGOR *sigalg,
                         EVP_PKEY *pkey) {
   assert(OBJ_obj2nid(sigalg->algorithm) == NID_rsassaPss);
-
-  // Decode PSS parameters
-  int ret = 0;
-  RSA_PSS_PARAMS *pss = rsa_pss_decode(sigalg);
-
-  {
-    if (pss == NULL) {
-      OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-      goto err;
-    }
-
-    const EVP_MD *mgf1md = rsa_mgf1_to_md(pss->maskGenAlgorithm);
-    const EVP_MD *md = rsa_algor_to_md(pss->hashAlgorithm);
-    if (mgf1md == NULL || md == NULL) {
-      goto err;
-    }
-
-    // We require the MGF-1 and signing hashes to match.
-    if (mgf1md != md) {
-      OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-      goto err;
-    }
-
-    // We require the salt length be the hash length. The DEFAULT value is 20,
-    // but this does not match any supported salt length.
-    uint64_t salt_len = 0;
-    if (pss->saltLength == NULL ||
-        !ASN1_INTEGER_get_uint64(&salt_len, pss->saltLength) ||
-        salt_len != EVP_MD_size(md)) {
-      OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-      goto err;
-    }
-    assert(salt_len <= INT_MAX);
-
-    // The trailer field must be 1 (0xbc). This value is DEFAULT, so the
-    // structure is required to omit it in DER. Although a syntax error, we also
-    // tolerate an explicitly-encoded value. See the certificates in
-    // cl/362617931.
-    if (pss->trailerField != NULL && ASN1_INTEGER_get(pss->trailerField) != 1) {
-      OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
-      goto err;
-    }
-
-    EVP_PKEY_CTX *pctx;
-    if (!EVP_DigestVerifyInit(ctx, &pctx, md, NULL, pkey) ||
-        !EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) ||
-        !EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, (int)salt_len) ||
-        !EVP_PKEY_CTX_set_rsa_mgf1_md(pctx, mgf1md)) {
-      goto err;
-    }
-
-    ret = 1;
+  rsa_pss_params_t params;
+  if (!rsa_pss_decode(sigalg, &params)) {
+    OPENSSL_PUT_ERROR(X509, X509_R_INVALID_PSS_PARAMETERS);
+    return 0;
   }
 
-err:
-  RSA_PSS_PARAMS_free(pss);
-  return ret;
+  const EVP_MD *md = rsa_pss_params_get_md(params);
+  EVP_PKEY_CTX *pctx;
+  if (!EVP_DigestVerifyInit(ctx, &pctx, md, nullptr, pkey) ||
+      !EVP_PKEY_CTX_set_rsa_padding(pctx, RSA_PKCS1_PSS_PADDING) ||
+      !EVP_PKEY_CTX_set_rsa_pss_saltlen(pctx, RSA_PSS_SALTLEN_DIGEST) ||
+      !EVP_PKEY_CTX_set_rsa_mgf1_md(pctx, md)) {
+    return 0;
+  }
+
+  return 1;
 }
 
 int x509_print_rsa_pss_params(BIO *bp, const X509_ALGOR *sigalg, int indent,
                               ASN1_PCTX *pctx) {
   assert(OBJ_obj2nid(sigalg->algorithm) == NID_rsassaPss);
+  rsa_pss_params_t params;
+  if (!rsa_pss_decode(sigalg, &params)) {
+    return BIO_puts(bp, " (INVALID PSS PARAMETERS)\n") <= 0;
+  }
 
-  int rv = 0;
-  X509_ALGOR *maskHash = NULL;
-  RSA_PSS_PARAMS *pss = rsa_pss_decode(sigalg);
-  if (!pss) {
-    if (BIO_puts(bp, " (INVALID PSS PARAMETERS)\n") <= 0) {
-      goto err;
-    }
-    rv = 1;
-    goto err;
+  const char *hash_str = nullptr;
+  uint32_t salt_len = 0;
+  switch (params) {
+    case rsa_pss_sha256:
+      hash_str = "sha256";
+      salt_len = 32;
+      break;
+    case rsa_pss_sha384:
+      hash_str = "sha384";
+      salt_len = 48;
+      break;
+    case rsa_pss_sha512:
+      hash_str = "sha512";
+      salt_len = 64;
+      break;
   }
 
   if (BIO_puts(bp, "\n") <= 0 ||       //
       !BIO_indent(bp, indent, 128) ||  //
-      BIO_puts(bp, "Hash Algorithm: ") <= 0) {
-    goto err;
-  }
-
-  if (pss->hashAlgorithm) {
-    if (i2a_ASN1_OBJECT(bp, pss->hashAlgorithm->algorithm) <= 0) {
-      goto err;
-    }
-  } else if (BIO_puts(bp, "sha1 (default)") <= 0) {
-    goto err;
-  }
-
-  if (BIO_puts(bp, "\n") <= 0 ||       //
+      BIO_printf(bp, "Hash Algorithm: %s\n", hash_str) <= 0 ||
       !BIO_indent(bp, indent, 128) ||  //
-      BIO_puts(bp, "Mask Algorithm: ") <= 0) {
-    goto err;
+      BIO_printf(bp, "Mask Algorithm: mgf1 with %s\n", hash_str) <= 0 ||
+      !BIO_indent(bp, indent, 128) ||  //
+      BIO_printf(bp, "Salt Length: 0x%x\n", salt_len) <= 0 ||
+      !BIO_indent(bp, indent, 128) ||  //
+      BIO_puts(bp, "Trailer Field: 0xBC (default)\n") <= 0) {
+    return 0;
   }
 
-  if (pss->maskGenAlgorithm) {
-    maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
-    if (maskHash == NULL) {
-      if (BIO_puts(bp, "INVALID") <= 0) {
-        goto err;
-      }
-    } else {
-      if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0 ||
-          BIO_puts(bp, " with ") <= 0 ||
-          i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0) {
-        goto err;
-      }
-    }
-  } else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0) {
-    goto err;
-  }
-  BIO_puts(bp, "\n");
-
-  if (!BIO_indent(bp, indent, 128) ||  //
-      BIO_puts(bp, "Salt Length: 0x") <= 0) {
-    goto err;
-  }
-
-  if (pss->saltLength) {
-    if (i2a_ASN1_INTEGER(bp, pss->saltLength) <= 0) {
-      goto err;
-    }
-  } else if (BIO_puts(bp, "14 (default)") <= 0) {
-    goto err;
-  }
-  BIO_puts(bp, "\n");
-
-  if (!BIO_indent(bp, indent, 128) ||  //
-      BIO_puts(bp, "Trailer Field: 0x") <= 0) {
-    goto err;
-  }
-
-  if (pss->trailerField) {
-    if (i2a_ASN1_INTEGER(bp, pss->trailerField) <= 0) {
-      goto err;
-    }
-  } else if (BIO_puts(bp, "BC (default)") <= 0) {
-    goto err;
-  }
-  BIO_puts(bp, "\n");
-
-  rv = 1;
-
-err:
-  RSA_PSS_PARAMS_free(pss);
-  X509_ALGOR_free(maskHash);
-  return rv;
+  return 1;
 }
diff --git a/include/openssl/digest.h b/include/openssl/digest.h
index 710c6e6..b604aba 100644
--- a/include/openssl/digest.h
+++ b/include/openssl/digest.h
@@ -213,6 +213,11 @@
 // returns the digest function or NULL on error.
 OPENSSL_EXPORT const EVP_MD *EVP_parse_digest_algorithm(CBS *cbs);
 
+// EVP_parse_digest_algorithm_nid behaves like |EVP_parse_digest_algorithm|
+// except it returns |NID_undef| on error and some other value on success. This
+// may be used to avoid depending on every digest algorithm in the library.
+OPENSSL_EXPORT int EVP_parse_digest_algorithm_nid(CBS *cbs);
+
 // EVP_marshal_digest_algorithm marshals |md| as an AlgorithmIdentifier
 // structure and appends the result to |cbb|. It returns one on success and zero
 // on error. It sets the parameters field to NULL. Use