ztest: Add zassume* API

Add an assume API which works like JUnit's. Assumptions can be made
at any point before your test returns (setup, before, and during the
test). If an assumption fails, the test will be marked as skipped.

This helps avoid a cascading affect of failed tests when a base
feature breaks. The feature is expected to have tests and the tests
which depend on it should be skipped (not failed) if that feature
is broken.

Issue #42472

Signed-off-by: Yuval Peress <peress@google.com>
diff --git a/doc/develop/test/ztest.rst b/doc/develop/test/ztest.rst
index ebfd2bd..8cba769 100644
--- a/doc/develop/test/ztest.rst
+++ b/doc/develop/test/ztest.rst
@@ -472,6 +472,26 @@
 
 .. doxygengroup:: ztest_assert
 
+Assumptions
+===========
+
+These macros will instantly skip the test or suite if the related assumption fails.
+When an assumption fails, it will print the current file, line, and function,
+alongside a reason for the failure and an optional message. If the config
+option:`CONFIG_ZTEST_ASSERT_VERBOSE` is 0, the assumptions will only print the
+file and line numbers, reducing the binary size of the test.
+
+Example output for a failed macro from
+``zassume_equal(buf->ref, 2, "Invalid refcount")``:
+
+.. code-block::none
+
+    START - test_get_single_buffer
+        Assumption failed at main.c:62: test_get_single_buffer: Invalid refcount (buf->ref not equal to 2)
+     SKIP - test_get_single_buffer in 0.0 seconds
+
+.. doxygengroup:: ztest_assume
+
 Mocking
 =======
 
diff --git a/subsys/testsuite/ztest/include/zephyr/ztest_assert.h b/subsys/testsuite/ztest/include/zephyr/ztest_assert.h
index 7057623..f96d225 100644
--- a/subsys/testsuite/ztest/include/zephyr/ztest_assert.h
+++ b/subsys/testsuite/ztest/include/zephyr/ztest_assert.h
@@ -26,6 +26,7 @@
 
 const char *ztest_relative_filename(const char *file);
 void ztest_test_fail(void);
+void ztest_test_skip(void);
 #if CONFIG_ZTEST_ASSERT_VERBOSE == 0
 
 static inline bool z_zassert_(bool cond, const char *file, int line)
@@ -41,6 +42,19 @@
 
 #define z_zassert(cond, default_msg, file, line, func, msg, ...) z_zassert_(cond, file, line)
 
+static inline bool z_zassume_(bool cond, const char *file, int line)
+{
+	if (cond == false) {
+		PRINT("\n    Assumption failed at %s:%d\n", ztest_relative_filename(file), line);
+		ztest_test_skip();
+		return false;
+	}
+
+	return true;
+}
+
+#define z_zassume(cond, default_msg, file, line, func, msg, ...) z_zassume_(cond, file, line)
+
 #else /* CONFIG_ZTEST_ASSERT_VERBOSE != 0 */
 
 static inline bool z_zassert(bool cond, const char *default_msg, const char *file, int line,
@@ -67,6 +81,30 @@
 	return true;
 }
 
+static inline bool z_zassume(bool cond, const char *default_msg, const char *file, int line,
+			     const char *func, const char *msg, ...)
+{
+	if (cond == false) {
+		va_list vargs;
+
+		va_start(vargs, msg);
+		PRINT("\n    Assumption failed at %s:%d: %s: %s\n", ztest_relative_filename(file),
+		      line, func, default_msg);
+		vprintk(msg, vargs);
+		printk("\n");
+		va_end(vargs);
+		ztest_test_skip();
+		return false;
+	}
+#if CONFIG_ZTEST_ASSERT_VERBOSE == 2
+	else {
+		PRINT("\n   Assumption succeeded at %s:%d (%s)\n", ztest_relative_filename(file),
+		      line, func);
+	}
+#endif
+	return true;
+}
+
 #endif /* CONFIG_ZTEST_ASSERT_VERBOSE */
 
 /**
@@ -103,6 +141,16 @@
 		}                                                                                  \
 	} while (0)
 
+#define zassume(cond, default_msg, msg, ...)                                                       \
+	do {                                                                                       \
+		bool _ret = z_zassume(cond, msg ? ("(" default_msg ")") : (default_msg), __FILE__, \
+				      __LINE__, __func__, msg ? msg : "", ##__VA_ARGS__);          \
+		if (!_ret) {                                                                       \
+			/* If kernel but without multithreading return. */                         \
+			COND_CODE_1(KERNEL, (COND_CODE_1(CONFIG_MULTITHREADING, (), (return;))),   \
+				    ())                                                            \
+		}                                                                                  \
+	} while (0)
 /**
  * @brief Assert that this function call won't be reached
  * @param msg Optional message to print if the assertion fails
@@ -237,6 +285,167 @@
  * @}
  */
 
+/**
+ * @defgroup ztest_assume Ztest assumption macros
+ * @ingroup ztest
+ *
+ * This module provides assumptions when using Ztest.
+ *
+ * @{
+ */
+
+/**
+ * @brief Assume that @a cond is true
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param cond Condition to check
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_true(cond, msg, ...) zassume(cond, #cond " is false", msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a cond is false
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param cond Condition to check
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_false(cond, msg, ...) zassume(!(cond), #cond " is true", msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a cond is 0 (success)
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param cond Condition to check
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_ok(cond, msg, ...) zassume(!(cond), #cond " is non-zero", msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a ptr is NULL
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param ptr Pointer to compare
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_is_null(ptr, msg, ...)                                                             \
+	zassume((ptr) == NULL, #ptr " is not NULL", msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a ptr is not NULL
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param ptr Pointer to compare
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_not_null(ptr, msg, ...) zassume((ptr) != NULL, #ptr " is NULL", msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a a equals @a b
+ *
+ * @a a and @a b won't be converted and will be compared directly. If the
+ * assumption fails, the test will be marked as "skipped".
+ *
+ * @param a Value to compare
+ * @param b Value to compare
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_equal(a, b, msg, ...)                                                              \
+	zassume((a) == (b), #a " not equal to " #b, msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a a does not equal @a b
+ *
+ * @a a and @a b won't be converted and will be compared directly. If the
+ * assumption fails, the test will be marked as "skipped".
+ *
+ * @param a Value to compare
+ * @param b Value to compare
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_not_equal(a, b, msg, ...)                                                          \
+	zassume((a) != (b), #a " equal to " #b, msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a a equals @a b
+ *
+ * @a a and @a b will be converted to `void *` before comparing. If the
+ * assumption fails, the test will be marked as "skipped".
+ *
+ * @param a Value to compare
+ * @param b Value to compare
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_equal_ptr(a, b, msg, ...)                                                          \
+	zassume((void *)(a) == (void *)(b), #a " not equal to " #b, msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a a is within @a b with delta @a d
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param a Value to compare
+ * @param b Value to compare
+ * @param d Delta
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_within(a, b, d, msg, ...)                                                          \
+	zassume(((a) >= ((b) - (d))) && ((a) <= ((b) + (d))), #a " not within " #b " +/- " #d,     \
+		msg, ##__VA_ARGS__)
+
+/**
+ * @brief Assume that @a a is greater than or equal to @a l and less
+ *        than or equal to @a u
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @param a Value to compare
+ * @param l Lower limit
+ * @param u Upper limit
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_between_inclusive(a, l, u, msg, ...)                                               \
+	zassume(((a) >= (l)) && ((a) <= (u)), #a " not between " #l " and " #u " inclusive", msg,  \
+		##__VA_ARGS__)
+
+/**
+ * @brief Assume that 2 memory buffers have the same contents
+ *
+ * This macro calls the final memory comparison assumption macro.
+ * Using double expansion allows providing some arguments by macros that
+ * would expand to more than one values (ANSI-C99 defines that all the macro
+ * arguments have to be expanded before macro call).
+ *
+ * @param ... Arguments, see @ref zassume_mem_equal__
+ *            for real arguments accepted.
+ */
+#define zassume_mem_equal(...) zassume_mem_equal__(__VA_ARGS__)
+
+/**
+ * @brief Internal assume that 2 memory buffers have the same contents
+ *
+ * If the assumption fails, the test will be marked as "skipped".
+ *
+ * @note This is internal macro, to be used as a second expansion.
+ *       See @ref zassume_mem_equal.
+ *
+ * @param buf Buffer to compare
+ * @param exp Buffer with expected contents
+ * @param size Size of buffers
+ * @param msg Optional message to print if the assumption fails
+ */
+#define zassume_mem_equal__(buf, exp, size, msg, ...)                                              \
+	zassume(memcmp(buf, exp, size) == 0, #buf " not equal to " #exp, msg, ##__VA_ARGS__)
+
+/**
+ * @}
+ */
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/subsys/testsuite/ztest/src/ztest_new.c b/subsys/testsuite/ztest/src/ztest_new.c
index abd39a9..7cd7df4 100644
--- a/subsys/testsuite/ztest/src/ztest_new.c
+++ b/subsys/testsuite/ztest/src/ztest_new.c
@@ -345,7 +345,9 @@
 void ztest_test_skip(void)
 {
 	test_result = ZTEST_RESULT_SKIP;
-	test_finalize();
+	if (phase != TEST_PHASE_SETUP) {
+		test_finalize();
+	}
 }
 
 void ztest_simple_1cpu_before(void *data)
@@ -384,6 +386,9 @@
 
 	phase = TEST_PHASE_BEFORE;
 
+	/* If the suite's setup function marked us as skipped, don't bother
+	 * running the tests.
+	 */
 	if (IS_ENABLED(CONFIG_MULTITHREADING)) {
 		k_thread_create(&ztest_thread, ztest_thread_stack,
 				K_THREAD_STACK_SIZEOF(ztest_thread_stack),
@@ -395,9 +400,12 @@
 		if (test->name != NULL) {
 			k_thread_name_set(&ztest_thread, test->name);
 		}
-		k_thread_start(&ztest_thread);
-		k_thread_join(&ztest_thread, K_FOREVER);
-	} else {
+		/* Only start the thread if we're not skipping the suite */
+		if (test_result != ZTEST_RESULT_SKIP) {
+			k_thread_start(&ztest_thread);
+			k_thread_join(&ztest_thread, K_FOREVER);
+		}
+	} else if (test_result != ZTEST_RESULT_SKIP) {
 		test_result = ZTEST_RESULT_PENDING;
 		run_test_rules(/*is_before=*/true, test, data);
 		if (suite->before) {
@@ -499,6 +507,7 @@
 	init_testing();
 
 	TC_SUITE_START(suite->name);
+	test_result = ZTEST_RESULT_PENDING;
 	phase = TEST_PHASE_SETUP;
 	if (suite->setup != NULL) {
 		data = suite->setup();
diff --git a/tests/ztest/error_hook/src/main.c b/tests/ztest/error_hook/src/main.c
index 01f40a3..a25b55c 100644
--- a/tests/ztest/error_hook/src/main.c
+++ b/tests/ztest/error_hook/src/main.c
@@ -4,9 +4,10 @@
  * SPDX-License-Identifier: Apache-2.0
  */
 
-#include <ztest.h>
 #include <zephyr/irq_offload.h>
 #include <zephyr/syscall_handler.h>
+
+#include <ztest.h>
 #include <ztest_error_hook.h>
 
 #define STACK_SIZE (1024 + CONFIG_TEST_EXTRA_STACK_SIZE)
@@ -360,3 +361,60 @@
 	return NULL;
 }
 ZTEST_SUITE(error_hook_tests, NULL, error_hook_tests_setup, NULL, NULL, NULL);
+
+static void *fail_assume_in_setup_setup(void)
+{
+	/* Fail the assume, will skip all the tests */
+	zassume_true(false, NULL);
+	return NULL;
+}
+
+ZTEST_SUITE(fail_assume_in_setup, NULL, fail_assume_in_setup_setup, NULL, NULL, NULL);
+
+ZTEST(fail_assume_in_setup, test_to_skip0)
+{
+	/* This test should never be run */
+	ztest_test_fail();
+}
+
+ZTEST(fail_assume_in_setup, test_to_skip1)
+{
+	/* This test should never be run */
+	ztest_test_fail();
+}
+
+static void fail_assume_in_before_before(void *unused)
+{
+	ARG_UNUSED(unused);
+	zassume_true(false, NULL);
+}
+
+ZTEST_SUITE(fail_assume_in_before, NULL, NULL, fail_assume_in_before_before, NULL, NULL);
+
+ZTEST(fail_assume_in_before, test_to_skip0)
+{
+	/* This test should never be run */
+	ztest_test_fail();
+}
+
+ZTEST(fail_assume_in_before, test_to_skip1)
+{
+	/* This test should never be run */
+	ztest_test_fail();
+}
+
+ZTEST_SUITE(fail_assume_in_test, NULL, NULL, NULL, NULL, NULL);
+
+ZTEST(fail_assume_in_test, test_to_skip)
+{
+	zassume_true(false, NULL);
+	ztest_test_fail();
+}
+
+void test_main(void)
+{
+	ztest_run_test_suites(NULL);
+	/* Can't run ztest_verify_all_test_suites_ran() since some tests are
+	 * skipped by design.
+	 */
+}