Add sk_FOO_delete_if.

Change-Id: I9abfef8d3d4d3ce0dabe29a5dc391fd4e0200d7f
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/56030
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: Bob Beck <bbe@google.com>
diff --git a/crypto/stack/stack.c b/crypto/stack/stack.c
index fe1b513..2dbb842 100644
--- a/crypto/stack/stack.c
+++ b/crypto/stack/stack.c
@@ -212,7 +212,7 @@
 
   if (where != sk->num - 1) {
     OPENSSL_memmove(&sk->data[where], &sk->data[where + 1],
-            sizeof(void *) * (sk->num - where - 1));
+                    sizeof(void *) * (sk->num - where - 1));
   }
 
   sk->num--;
@@ -233,6 +233,22 @@
   return NULL;
 }
 
+void sk_delete_if(_STACK *sk, OPENSSL_sk_call_delete_if_func call_func,
+                  OPENSSL_sk_delete_if_func func, void *data) {
+  if (sk == NULL) {
+    return;
+  }
+
+  size_t new_num = 0;
+  for (size_t i = 0; i < sk->num; i++) {
+    if (!call_func(func, sk->data[i], data)) {
+      sk->data[new_num] = sk->data[i];
+      new_num++;
+    }
+  }
+  sk->num = new_num;
+}
+
 int sk_find(const _STACK *sk, size_t *out_index, const void *p,
             OPENSSL_sk_call_cmp_func call_cmp_func) {
   if (sk == NULL) {
diff --git a/crypto/stack/stack_test.cc b/crypto/stack/stack_test.cc
index 7be84ed..9a7832c 100644
--- a/crypto/stack/stack_test.cc
+++ b/crypto/stack/stack_test.cc
@@ -290,6 +290,7 @@
     // Removing elements does not affect sortedness.
     TEST_INT_free(sk_TEST_INT_delete(sk.get(), 0));
     EXPECT_TRUE(sk_TEST_INT_is_sorted(sk.get()));
+    EXPECT_TRUE(sk_TEST_INT_is_sorted(sk.get()));
 
     // Changing the comparison function invalidates sortedness.
     sk_TEST_INT_set_cmp_func(sk.get(), compare_reverse);
@@ -392,3 +393,44 @@
     }
   }
 }
+
+TEST(StackTest, DeleteIf) {
+  bssl::UniquePtr<STACK_OF(TEST_INT)> sk(sk_TEST_INT_new(compare));
+  for (int v : {1, 9, 2, 8, 3, 7, 4, 6, 5}) {
+    auto obj = TEST_INT_new(v);
+    ASSERT_TRUE(obj);
+    ASSERT_TRUE(bssl::PushToStack(sk.get(), std::move(obj)));
+  }
+
+  auto keep_only_multiples = [](TEST_INT *x, void *data) {
+    auto d = static_cast<const int *>(data);
+    if (*x % *d == 0) {
+      return 0;
+    }
+    TEST_INT_free(x);
+    return 1;
+  };
+
+  int d = 2;
+  sk_TEST_INT_delete_if(sk.get(), keep_only_multiples, &d);
+  ExpectStackEquals(sk.get(), {2, 8, 4, 6});
+
+  EXPECT_FALSE(sk_TEST_INT_is_sorted(sk.get()));
+  sk_TEST_INT_sort(sk.get());
+  ExpectStackEquals(sk.get(), {2, 4, 6, 8});
+  EXPECT_TRUE(sk_TEST_INT_is_sorted(sk.get()));
+
+  // Keep only multiples of four.
+  d = 4;
+  sk_TEST_INT_delete_if(sk.get(), keep_only_multiples, &d);
+  ExpectStackEquals(sk.get(), {4, 8});
+
+  // Removing elements preserves the sorted bit.
+  EXPECT_TRUE(sk_TEST_INT_is_sorted(sk.get()));
+
+  // Delete everything.
+  d = 16;
+  sk_TEST_INT_delete_if(sk.get(), keep_only_multiples, &d);
+  ExpectStackEquals(sk.get(), {});
+  EXPECT_TRUE(sk_TEST_INT_is_sorted(sk.get()));
+}
diff --git a/include/openssl/stack.h b/include/openssl/stack.h
index 11761dc..cf03002 100644
--- a/include/openssl/stack.h
+++ b/include/openssl/stack.h
@@ -178,6 +178,17 @@
 // otherwise it returns NULL.
 SAMPLE *sk_SAMPLE_delete_ptr(STACK_OF(SAMPLE) *sk, const SAMPLE *p);
 
+// sk_SAMPLE_delete_if_func is the callback function for |sk_SAMPLE_delete_if|.
+// It should return one to remove |p| and zero to keep it.
+typedef int (*sk_SAMPLE_delete_if_func)(SAMPLE *p, void *data);
+
+// sk_SAMPLE_delete_if calls |func| with each element of |sk| and removes the
+// entries where |func| returned one. This function does not free or return
+// removed pointers so, if |sk| owns its contents, |func| should release the
+// pointers prior to returning one.
+void sk_SAMPLE_delete_if(STACK_OF(SAMPLE) *sk, sk_SAMPLE_delete_if_func func,
+                         void *data);
+
 // sk_SAMPLE_find find the first value in |sk| equal to |p|. |sk|'s comparison
 // function determines equality, or pointer equality if |sk| has no comparison
 // function.
@@ -260,6 +271,10 @@
 // in OpenSSL 1.1.1, so hopefully we can fix this compatibly.
 typedef int (*OPENSSL_sk_cmp_func)(const void **a, const void **b);
 
+// OPENSSL_sk_delete_if_func is the generic version of
+// |sk_SAMPLE_delete_if_func|.
+typedef int (*OPENSSL_sk_delete_if_func)(void *obj, void *data);
+
 // The following function types call the above type-erased signatures with the
 // true types.
 typedef void (*OPENSSL_sk_call_free_func)(OPENSSL_sk_free_func, void *);
@@ -267,6 +282,8 @@
 typedef int (*OPENSSL_sk_call_cmp_func)(OPENSSL_sk_cmp_func,
                                         const void *const *,
                                         const void *const *);
+typedef int (*OPENSSL_sk_call_delete_if_func)(OPENSSL_sk_delete_if_func, void *,
+                                              void *);
 
 // stack_st contains an array of pointers. It is not designed to be used
 // directly, rather the wrapper macros should be used.
@@ -300,6 +317,9 @@
 OPENSSL_EXPORT size_t sk_insert(_STACK *sk, void *p, size_t where);
 OPENSSL_EXPORT void *sk_delete(_STACK *sk, size_t where);
 OPENSSL_EXPORT void *sk_delete_ptr(_STACK *sk, const void *p);
+OPENSSL_EXPORT void sk_delete_if(_STACK *sk,
+                                 OPENSSL_sk_call_delete_if_func call_func,
+                                 OPENSSL_sk_delete_if_func func, void *data);
 OPENSSL_EXPORT int sk_find(const _STACK *sk, size_t *out_index, const void *p,
                            OPENSSL_sk_call_cmp_func call_cmp_func);
 OPENSSL_EXPORT void *sk_shift(_STACK *sk);
@@ -366,7 +386,8 @@
                                                                               \
   typedef void (*sk_##name##_free_func)(ptrtype);                             \
   typedef ptrtype (*sk_##name##_copy_func)(ptrtype);                          \
-  typedef int (*sk_##name##_cmp_func)(constptrtype *a, constptrtype *b);      \
+  typedef int (*sk_##name##_cmp_func)(constptrtype *, constptrtype *);        \
+  typedef int (*sk_##name##_delete_if_func)(ptrtype, void *);                 \
                                                                               \
   OPENSSL_INLINE void sk_##name##_call_free_func(                             \
       OPENSSL_sk_free_func free_func, void *ptr) {                            \
@@ -389,6 +410,11 @@
     return ((sk_##name##_cmp_func)cmp_func)(&a_ptr, &b_ptr);                  \
   }                                                                           \
                                                                               \
+  OPENSSL_INLINE int sk_##name##_call_delete_if_func(                         \
+      OPENSSL_sk_delete_if_func func, void *obj, void *data) {                \
+    return ((sk_##name##_delete_if_func)func)((ptrtype)obj, data);            \
+  }                                                                           \
+                                                                              \
   OPENSSL_INLINE STACK_OF(name) *sk_##name##_new(sk_##name##_cmp_func comp) { \
     return (STACK_OF(name) *)sk_new((OPENSSL_sk_cmp_func)comp);               \
   }                                                                           \
@@ -440,6 +466,12 @@
     return (ptrtype)sk_delete_ptr((_STACK *)sk, (const void *)p);             \
   }                                                                           \
                                                                               \
+  OPENSSL_INLINE void sk_##name##_delete_if(                                  \
+      STACK_OF(name) *sk, sk_##name##_delete_if_func func, void *data) {      \
+    sk_delete_if((_STACK *)sk, sk_##name##_call_delete_if_func,               \
+                 (OPENSSL_sk_delete_if_func)func, data);                      \
+  }                                                                           \
+                                                                              \
   OPENSSL_INLINE int sk_##name##_find(const STACK_OF(name) *sk,               \
                                       size_t *out_index, constptrtype p) {    \
     return sk_find((const _STACK *)sk, out_index, (const void *)p,            \