Wire up ML-DSA keygen to the EVP API

This works for both EVP_PKEY_CTX_new_id and also the more friendly
EVP_PKEY_ALG API.

Bug: 449751916
Change-Id: Id342948327b627e0c3ab29e02349e1f29a0fa4c8
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/90609
Reviewed-by: Lily Chen <chlily@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/evp/evp_ctx.cc b/crypto/evp/evp_ctx.cc
index 66c3e16..d06d4d5 100644
--- a/crypto/evp/evp_ctx.cc
+++ b/crypto/evp/evp_ctx.cc
@@ -89,6 +89,15 @@
     case EVP_PKEY_HKDF:
       alg = evp_pkey_hkdf();
       break;
+    case EVP_PKEY_ML_DSA_44:
+      alg = EVP_pkey_ml_dsa_44();
+      break;
+    case EVP_PKEY_ML_DSA_65:
+      alg = EVP_pkey_ml_dsa_65();
+      break;
+    case EVP_PKEY_ML_DSA_87:
+      alg = EVP_pkey_ml_dsa_87();
+      break;
   }
   if (alg == nullptr || alg->pkey_method == nullptr) {
     OPENSSL_PUT_ERROR(EVP, EVP_R_UNSUPPORTED_ALGORITHM);
diff --git a/crypto/evp/evp_extra_test.cc b/crypto/evp/evp_extra_test.cc
index aa96715..511a85f 100644
--- a/crypto/evp/evp_extra_test.cc
+++ b/crypto/evp/evp_extra_test.cc
@@ -1046,34 +1046,79 @@
   }
 }
 
+static void CheckSignAndVerify(EVP_PKEY *pkey) {
+  auto msg = StringAsBytes("hello");
+  ScopedEVP_MD_CTX ctx;
+  ASSERT_TRUE(EVP_DigestSignInit(ctx.get(), nullptr, nullptr, nullptr, pkey));
+  size_t len;
+  ASSERT_TRUE(EVP_DigestSign(ctx.get(), nullptr, &len, msg.data(), msg.size()));
+  std::vector<uint8_t> sig(len);
+  ASSERT_TRUE(
+      EVP_DigestSign(ctx.get(), sig.data(), &len, msg.data(), msg.size()));
+  sig.resize(len);
+
+  ctx.Reset();
+  ASSERT_TRUE(EVP_DigestVerifyInit(ctx.get(), nullptr, nullptr, nullptr, pkey));
+  ASSERT_TRUE(EVP_DigestVerify(ctx.get(), sig.data(), sig.size(), msg.data(),
+                               msg.size()));
+}
+
 // Test that |EVP_PKEY_keygen| works for Ed25519.
 TEST(EVPExtraTest, Ed25519Keygen) {
-  auto check_sign_and_verify = [&](EVP_PKEY *pkey) {
-    ScopedEVP_MD_CTX ctx;
-    ASSERT_TRUE(EVP_DigestSignInit(ctx.get(), nullptr, nullptr, nullptr, pkey));
-    uint8_t sig[64];
-    size_t len = sizeof(sig);
-    ASSERT_TRUE(EVP_DigestSign(ctx.get(), sig, &len,
-                               reinterpret_cast<const uint8_t *>("hello"), 5));
-
-    ctx.Reset();
-    ASSERT_TRUE(
-        EVP_DigestVerifyInit(ctx.get(), nullptr, nullptr, nullptr, pkey));
-    ASSERT_TRUE(EVP_DigestVerify(
-        ctx.get(), sig, len, reinterpret_cast<const uint8_t *>("hello"), 5));
-  };
-
   UniquePtr<EVP_PKEY_CTX> pctx(EVP_PKEY_CTX_new_id(EVP_PKEY_ED25519, nullptr));
   ASSERT_TRUE(pctx);
   ASSERT_TRUE(EVP_PKEY_keygen_init(pctx.get()));
   EVP_PKEY *raw = nullptr;
   ASSERT_TRUE(EVP_PKEY_keygen(pctx.get(), &raw));
   UniquePtr<EVP_PKEY> pkey(raw);
-  check_sign_and_verify(pkey.get());
+  CheckSignAndVerify(pkey.get());
 
   pkey.reset(EVP_PKEY_generate_from_alg(EVP_pkey_ed25519()));
   ASSERT_TRUE(pkey);
-  check_sign_and_verify(pkey.get());
+  CheckSignAndVerify(pkey.get());
+}
+
+TEST(EVPExtraTest, MLDSAKeyGen) {
+  const struct {
+    int type;
+    const EVP_PKEY_ALG *alg;
+  } kAlgs[] = {
+      {EVP_PKEY_ML_DSA_44, EVP_pkey_ml_dsa_44()},
+      {EVP_PKEY_ML_DSA_65, EVP_pkey_ml_dsa_65()},
+      {EVP_PKEY_ML_DSA_87, EVP_pkey_ml_dsa_87()},
+  };
+  for (const auto &alg : kAlgs) {
+    SCOPED_TRACE(alg.type);
+    UniquePtr<EVP_PKEY_CTX> ctx(EVP_PKEY_CTX_new_id(alg.type, nullptr));
+    ASSERT_TRUE(ctx);
+    ASSERT_TRUE(EVP_PKEY_keygen_init(ctx.get()));
+    EVP_PKEY *raw = nullptr;
+    ASSERT_TRUE(EVP_PKEY_keygen(ctx.get(), &raw));
+    UniquePtr<EVP_PKEY> pkey(raw);
+
+    // The key should be functional. Test that we can round-trip a signature.
+    EXPECT_EQ(EVP_PKEY_id(pkey.get()), alg.type);
+    CheckSignAndVerify(pkey.get());
+
+    // Generate a second key. They should be different.
+    ctx.reset(EVP_PKEY_CTX_new_id(alg.type, nullptr));
+    ASSERT_TRUE(ctx);
+    ASSERT_TRUE(EVP_PKEY_keygen_init(ctx.get()));
+    raw = nullptr;
+    ASSERT_TRUE(EVP_PKEY_keygen(ctx.get(), &raw));
+    UniquePtr<EVP_PKEY> pkey2(raw);
+
+    EXPECT_EQ(EVP_PKEY_id(pkey2.get()), alg.type);
+    CheckSignAndVerify(pkey2.get());
+    EXPECT_EQ(EVP_PKEY_cmp(pkey.get(), pkey2.get()), 0);
+
+    // The less messy API should also work.
+    UniquePtr<EVP_PKEY> pkey3(EVP_PKEY_generate_from_alg(alg.alg));
+    ASSERT_TRUE(pkey3);
+    EXPECT_EQ(EVP_PKEY_id(pkey3.get()), alg.type);
+    CheckSignAndVerify(pkey3.get());
+    EXPECT_EQ(EVP_PKEY_cmp(pkey.get(), pkey3.get()), 0);
+  }
 }
 
 // Test that OpenSSL's legacy TLS-specific APIs in EVP work correctly. When we
diff --git a/crypto/evp/p_mldsa.cc b/crypto/evp/p_mldsa.cc
index 76008e3..6b763a9 100644
--- a/crypto/evp/p_mldsa.cc
+++ b/crypto/evp/p_mldsa.cc
@@ -50,6 +50,7 @@
     static constexpr Span<const uint8_t> kOID = kMLDSA##kl##OID;              \
     static constexpr auto PrivateKeyFromSeed =                                \
         &MLDSA##kl##_private_key_from_seed;                                   \
+    static constexpr auto GenerateKey = &MLDSA##kl##_generate_key;            \
     static constexpr auto Sign = &MLDSA##kl##_sign;                           \
     static constexpr auto ParsePublicKey = &MLDSA##kl##_parse_public_key;     \
     static constexpr auto PublicOfPrivate =                                   \
@@ -377,6 +378,19 @@
     return 1;
   }
 
+  static int KeyGen(EvpPkeyCtx *ctx, EvpPkey *pkey) {
+    auto priv = MakeUnique<PrivateKeyData<Traits>>();
+    if (priv == nullptr) {
+      return 0;
+    }
+    uint8_t unused_public[Traits::kPublicKeyBytes];
+    if (!Traits::GenerateKey(unused_public, priv->seed, &priv->priv)) {
+      return 0;
+    }
+    evp_pkey_set0(pkey, &asn1_method, priv.release());
+    return 1;
+  }
+
   static int SignMessage(EvpPkeyCtx *ctx, uint8_t *sig, size_t *siglen,
                          const uint8_t *tbs, size_t tbslen) {
     const auto *priv_data = GetKeyData(ctx->pkey.get())->AsPrivateKeyData();
@@ -437,8 +451,7 @@
       &Init,
       &CopyContext,
       &Cleanup,
-      // TODO(crbug.com/449751916): Add keygen support.
-      /*keygen=*/nullptr,
+      &KeyGen,
       /*sign=*/nullptr,
       &SignMessage,
       /*verify=*/nullptr,