Check invalid context in multialg open-dice

Bug: b/357008987
Change-Id: I9e14428097471fa7b1c28d5950070b2faa66be59
Reviewed-on: https://pigweed-review.googlesource.com/c/open-dice/+/247672
Commit-Queue: Alice Wang <aliceywang@google.com>
Presubmit-Verified: CQ Bot Account <pigweed-scoped@luci-project-accounts.iam.gserviceaccount.com>
Lint: Lint 🤖 <android-build-ayeaye@system.gserviceaccount.com>
Reviewed-by: Dan Fess <dfess@google.com>
Reviewed-by: Darren Krahn <dkrahn@google.com>
diff --git a/include/dice/config/boringssl_multialg/dice/config.h b/include/dice/config/boringssl_multialg/dice/config.h
index d8a965d..1158a1b 100644
--- a/include/dice/config/boringssl_multialg/dice/config.h
+++ b/include/dice/config/boringssl_multialg/dice/config.h
@@ -42,15 +42,22 @@
   DiceKeyAlgorithm subject_algorithm;
 } DiceContext;
 
-static inline DiceKeyAlgorithm DiceGetKeyAlgorithm(void* context,
-                                                   DicePrincipal principal) {
+static inline DiceResult DiceGetKeyAlgorithm(void* context,
+                                             DicePrincipal principal,
+                                             DiceKeyAlgorithm* alg) {
   DiceContext* c = (DiceContext*)context;
+  if (context == NULL) {
+    return kDiceResultInvalidInput;
+  }
   switch (principal) {
     case kDicePrincipalAuthority:
-      return c->authority_algorithm;
+      *alg = c->authority_algorithm;
+      break;
     case kDicePrincipalSubject:
-      return c->subject_algorithm;
+      *alg = c->subject_algorithm;
+      break;
   }
+  return kDiceResultOk;
 }
 
 #ifdef __cplusplus
diff --git a/src/boringssl_multialg_ops.c b/src/boringssl_multialg_ops.c
index 80c0903..745329e 100644
--- a/src/boringssl_multialg_ops.c
+++ b/src/boringssl_multialg_ops.c
@@ -43,7 +43,12 @@
 
 DiceResult DiceGetKeyParam(void* context, DicePrincipal principal,
                            DiceKeyParam* key_param) {
-  switch (DiceGetKeyAlgorithm(context, principal)) {
+  DiceKeyAlgorithm alg;
+  DiceResult result = DiceGetKeyAlgorithm(context, principal, &alg);
+  if (result != kDiceResultOk) {
+    return result;
+  }
+  switch (alg) {
     case kDiceKeyAlgorithmEd25519:
       key_param->profile_name = DICE_PROFILE_NAME_ED25519;
       key_param->public_key_size = 32;
@@ -80,7 +85,12 @@
     const uint8_t seed[DICE_PRIVATE_KEY_SEED_SIZE],
     uint8_t public_key[DICE_PUBLIC_KEY_BUFFER_SIZE],
     uint8_t private_key[DICE_PRIVATE_KEY_BUFFER_SIZE]) {
-  switch (DiceGetKeyAlgorithm(context, principal)) {
+  DiceKeyAlgorithm alg;
+  DiceResult result = DiceGetKeyAlgorithm(context, principal, &alg);
+  if (result != kDiceResultOk) {
+    return result;
+  }
+  switch (alg) {
     case kDiceKeyAlgorithmEd25519:
       ED25519_keypair_from_seed(public_key, private_key, seed);
       return kDiceResultOk;
@@ -101,7 +111,13 @@
 DiceResult DiceSign(void* context, const uint8_t* message, size_t message_size,
                     const uint8_t private_key[DICE_PRIVATE_KEY_BUFFER_SIZE],
                     uint8_t signature[DICE_SIGNATURE_BUFFER_SIZE]) {
-  switch (DiceGetKeyAlgorithm(context, kDicePrincipalAuthority)) {
+  DiceKeyAlgorithm alg;
+  DiceResult result =
+      DiceGetKeyAlgorithm(context, kDicePrincipalAuthority, &alg);
+  if (result != kDiceResultOk) {
+    return result;
+  }
+  switch (alg) {
     case kDiceKeyAlgorithmEd25519:
       if (1 == ED25519_sign(signature, message, message_size, private_key)) {
         return kDiceResultOk;
@@ -125,7 +141,13 @@
                       size_t message_size,
                       const uint8_t signature[DICE_SIGNATURE_BUFFER_SIZE],
                       const uint8_t public_key[DICE_PUBLIC_KEY_BUFFER_SIZE]) {
-  switch (DiceGetKeyAlgorithm(context, kDicePrincipalAuthority)) {
+  DiceKeyAlgorithm alg;
+  DiceResult result =
+      DiceGetKeyAlgorithm(context, kDicePrincipalAuthority, &alg);
+  if (result != kDiceResultOk) {
+    return result;
+  }
+  switch (alg) {
     case kDiceKeyAlgorithmEd25519:
       if (1 == ED25519_verify(message, message_size, signature, public_key)) {
         return kDiceResultOk;
diff --git a/src/cbor_multialg_op_test.cc b/src/cbor_multialg_op_test.cc
index 28886bd..8e8eadd 100644
--- a/src/cbor_multialg_op_test.cc
+++ b/src/cbor_multialg_op_test.cc
@@ -35,6 +35,17 @@
 using dice::test::KeyType_P256;
 using dice::test::KeyType_P384;
 
+TEST(DiceOpsTest, InvalidContextReturnsError) {
+  DiceStateForTest current_state = {};
+  DiceStateForTest next_state = {};
+  DiceInputValues input_values = {};
+  DiceResult result = DiceMainFlow(
+      NULL, current_state.cdi_attest, current_state.cdi_seal, &input_values,
+      sizeof(next_state.certificate), next_state.certificate,
+      &next_state.certificate_size, next_state.cdi_attest, next_state.cdi_seal);
+  EXPECT_EQ(kDiceResultInvalidInput, result);
+}
+
 TEST(DiceOpsTest, Ed25519KnownAnswerZeroInput) {
   DiceContext context{.authority_algorithm = kDiceKeyAlgorithmEd25519,
                       .subject_algorithm = kDiceKeyAlgorithmEd25519};