pw_sync_threadx: Disable thread switching in the ISL

Updates the pw::sync::InterruptSpinLock implementation for ThreadX
to disable thread switching if not in an interrupt by raising the
preemption threshold. This way there's no risk that the current
thread gets switched out to another which recursively attempts to
grab the same lock.

This also changes the internal state from an atomic bool to a
plain (non-atomic) enum which additionally records whether the
lock was acquired from interrupt or thread context, as an atomic
is not necessary.

Change-Id: Id02869e66abaee956d1275dc181d1c14777c9cbe
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/50463
Reviewed-by: Wyatt Hepler <hepler@google.com>
Commit-Queue: Ewout van Bekkum <ewout@google.com>
Pigweed-Auto-Submit: Ewout van Bekkum <ewout@google.com>
diff --git a/pw_sync_threadx/BUILD.gn b/pw_sync_threadx/BUILD.gn
index b02d757..afeebe3 100644
--- a/pw_sync_threadx/BUILD.gn
+++ b/pw_sync_threadx/BUILD.gn
@@ -157,6 +157,7 @@
   sources = [ "interrupt_spin_lock.cc" ]
   deps = [
     "$dir_pw_assert",
+    "$dir_pw_interrupt:context",
     "$dir_pw_sync:interrupt_spin_lock.facade",
   ]
 }
diff --git a/pw_sync_threadx/interrupt_spin_lock.cc b/pw_sync_threadx/interrupt_spin_lock.cc
index bb73d4e..96b39ca 100644
--- a/pw_sync_threadx/interrupt_spin_lock.cc
+++ b/pw_sync_threadx/interrupt_spin_lock.cc
@@ -15,42 +15,120 @@
 #include "pw_sync/interrupt_spin_lock.h"
 
 #include "pw_assert/check.h"
+#include "pw_interrupt/context.h"
 #include "tx_api.h"
 
 namespace pw::sync {
+namespace {
+
+using State = backend::NativeInterruptSpinLock::State;
+
+}  // namespace
 
 void InterruptSpinLock::lock() {
   // In order to be pw::sync::InterruptSpinLock compliant, mask the interrupts
   // before attempting to grab the internal spin lock.
   native_type_.saved_interrupt_mask = tx_interrupt_control(TX_INT_DISABLE);
 
+  const bool in_interrupt = interrupt::InInterruptContext();
+
+  // Disable thread switching to ensure kernel APIs cannot switch to other
+  // threads which could then end up deadlocking recursively on this same lock.
+  if (!in_interrupt) {
+    TX_THREAD* current_thread = tx_thread_identify();
+    // During init, i.e. tx_application_define, there may not be a thread yet.
+    if (current_thread != nullptr) {
+      // Disable thread switching by raising the preemption threshold to the
+      // highest priority value of 0.
+      UINT preemption_success = tx_thread_preemption_change(
+          tx_thread_identify(), 0, &native_type_.saved_preemption_threshold);
+      PW_DCHECK_UINT_EQ(
+          TX_SUCCESS, preemption_success, "Failed to disable thread switching");
+    }
+  }
+
   // This implementation is not set up to support SMP, meaning we cannot
   // deadlock here due to the global interrupt lock, so we crash on recursion
   // on a specific spinlock instead.
-  PW_CHECK(!native_type_.locked.load(std::memory_order_relaxed),
-           "Recursive InterruptSpinLock::lock() detected");
+  PW_CHECK_UINT_EQ(native_type_.state,
+                   State::kUnlocked,
+                   "Recursive InterruptSpinLock::lock() detected");
 
-  native_type_.locked.store(true, std::memory_order_relaxed);
+  native_type_.state =
+      in_interrupt ? State::kLockedFromInterrupt : State::kLockedFromThread;
 }
 
 bool InterruptSpinLock::try_lock() {
   // In order to be pw::sync::InterruptSpinLock compliant, mask the interrupts
   // before attempting to grab the internal spin lock.
-  UINT saved_interrupt_mask = tx_interrupt_control(TX_INT_DISABLE);
+  const UINT saved_interrupt_mask = tx_interrupt_control(TX_INT_DISABLE);
 
-  if (native_type_.locked.load(std::memory_order_relaxed)) {
-    // Already locked, restore interrupts and bail out.
+  const bool in_interrupt = interrupt::InInterruptContext();
+
+  // Disable thread switching to ensure kernel APIs cannot switch to other
+  // threads which could then end up deadlocking recursively on this same lock.
+  if (!in_interrupt) {
+    TX_THREAD* current_thread = tx_thread_identify();
+    // During init, i.e. tx_application_define, there may not be a thread yet.
+    if (current_thread != nullptr) {
+      // Disable thread switching by raising the preemption threshold to the
+      // highest priority value of 0.
+      UINT preemption_success = tx_thread_preemption_change(
+          tx_thread_identify(), 0, &native_type_.saved_preemption_threshold);
+      PW_DCHECK_UINT_EQ(
+          TX_SUCCESS, preemption_success, "Failed to disable thread switching");
+    }
+  }
+
+  if (native_type_.state != State::kUnlocked) {
+    // Already locked, restore interrupts and thread switching and bail out.
+    if (!interrupt::InInterruptContext()) {
+      // Restore thread switching.
+      UINT unused = 0;
+      UINT preemption_success =
+          tx_thread_preemption_change(tx_thread_identify(),
+                                      native_type_.saved_preemption_threshold,
+                                      &unused);
+      PW_DCHECK_UINT_EQ(
+          TX_SUCCESS, preemption_success, "Failed to restore thread switching");
+    }
     tx_interrupt_control(saved_interrupt_mask);
     return false;
   }
 
   native_type_.saved_interrupt_mask = saved_interrupt_mask;
-  native_type_.locked.store(true, std::memory_order_relaxed);
+  native_type_.state =
+      in_interrupt ? State::kLockedFromInterrupt : State::kLockedFromThread;
   return true;
 }
 
 void InterruptSpinLock::unlock() {
-  native_type_.locked.store(false, std::memory_order_relaxed);
+  const bool in_interrupt = interrupt::InInterruptContext();
+
+  const State expected_state =
+      in_interrupt ? State::kLockedFromInterrupt : State::kLockedFromThread;
+  PW_CHECK_UINT_EQ(
+      native_type_.state,
+      expected_state,
+      "InterruptSpinLock::unlock() was called from a different context "
+      "compared to the lock()");
+
+  native_type_.state = State::kUnlocked;
+
+  if (!in_interrupt) {
+    TX_THREAD* current_thread = tx_thread_identify();
+    // During init, i.e. tx_application_define, there may not be a thread yet.
+    if (current_thread != nullptr) {
+      // Restore thread switching.
+      UINT unused = 0;
+      UINT preemption_success =
+          tx_thread_preemption_change(tx_thread_identify(),
+                                      native_type_.saved_preemption_threshold,
+                                      &unused);
+      PW_DCHECK_UINT_EQ(
+          TX_SUCCESS, preemption_success, "Failed to restore thread switching");
+    }
+  }
   tx_interrupt_control(native_type_.saved_interrupt_mask);
 }
 
diff --git a/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_inline.h b/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_inline.h
index 8bbd4cf..f66dc22 100644
--- a/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_inline.h
+++ b/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_inline.h
@@ -18,7 +18,9 @@
 namespace pw::sync {
 
 constexpr InterruptSpinLock::InterruptSpinLock()
-    : native_type_{.locked{false}, .saved_interrupt_mask = 0} {}
+    : native_type_{.state = backend::NativeInterruptSpinLock::State::kUnlocked,
+                   .saved_interrupt_mask = 0,
+                   .saved_preemption_threshold = 0} {}
 
 inline InterruptSpinLock::native_handle_type
 InterruptSpinLock::native_handle() {
diff --git a/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_native.h b/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_native.h
index cf6ceb8..567807f 100644
--- a/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_native.h
+++ b/pw_sync_threadx/public/pw_sync_threadx/interrupt_spin_lock_native.h
@@ -13,15 +13,19 @@
 // the License.
 #pragma once
 
-#include <atomic>
-
 #include "tx_api.h"
 
 namespace pw::sync::backend {
 
 struct NativeInterruptSpinLock {
-  std::atomic<bool> locked;  // Used to detect recursion.
+  enum class State {
+    kUnlocked = 1,
+    kLockedFromInterrupt = 2,
+    kLockedFromThread = 3,
+  };
+  State state;  // Used to detect recursion and interrupt context escapes.
   UINT saved_interrupt_mask;
+  UINT saved_preemption_threshold;
 };
 using NativeInterruptSpinLockHandle = NativeInterruptSpinLock&;