[K/N] try to reimplement safe points with a load-branch-call
diff --git a/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.cpp b/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.cpp
index 1a535545..582e965 100644
--- a/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.cpp
+++ b/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.cpp
@@ -61,7 +61,7 @@
 
 void gc::ConcurrentMarkAndSweep::ThreadData::SafePointAllocation(size_t size) noexcept {
     gcScheduler_.OnSafePointAllocation(size);
-    mm::SuspendIfRequested();
+    mm::SafePoint(); // Runs the pending safe point action (e.g. a thread suspension request), if any.
 }
 
 void gc::ConcurrentMarkAndSweep::ThreadData::Schedule() noexcept {
@@ -105,6 +105,10 @@
     gc::Mark<internal::MarkTraits>(handle, markQueue);
 }
 
+mm::ThreadData& gc::ConcurrentMarkAndSweep::ThreadData::CommonThreadData() noexcept {
+    return threadData_; // Exposes the common mm::ThreadData stored by this GC thread data.
+}
+
 #ifndef CUSTOM_ALLOCATOR
 gc::ConcurrentMarkAndSweep::ConcurrentMarkAndSweep(
         mm::ObjectFactory<ConcurrentMarkAndSweep>& objectFactory, GCScheduler& gcScheduler) noexcept :
diff --git a/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.hpp b/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.hpp
index c2bc62e..ae95a40 100644
--- a/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.hpp
+++ b/kotlin-native/runtime/src/gc/cms/cpp/ConcurrentMarkAndSweep.hpp
@@ -98,6 +98,8 @@
 
         Allocator CreateAllocator() noexcept { return Allocator(gc::Allocator(), *this); }
 
+        mm::ThreadData& CommonThreadData() noexcept; // Returns the common mm::ThreadData stored by this GC thread data.
+
     private:
         friend ConcurrentMarkAndSweep;
         ConcurrentMarkAndSweep& gc_;
diff --git a/kotlin-native/runtime/src/gc/cms/cpp/GCImpl.cpp b/kotlin-native/runtime/src/gc/cms/cpp/GCImpl.cpp
index 194cf29..c0dfec7 100644
--- a/kotlin-native/runtime/src/gc/cms/cpp/GCImpl.cpp
+++ b/kotlin-native/runtime/src/gc/cms/cpp/GCImpl.cpp
@@ -17,7 +17,7 @@
 
 ALWAYS_INLINE void SafePointRegular(gc::GC::ThreadData& threadData, size_t weight) noexcept {
     threadData.impl().gcScheduler().OnSafePointRegular(weight);
-    mm::SuspendIfRequested();
+    mm::SafePoint(); // Runs the pending safe point action (e.g. a thread suspension request), if any.
 }
 
 } // namespace
diff --git a/kotlin-native/runtime/src/gc/stms/cpp/SameThreadMarkAndSweep.cpp b/kotlin-native/runtime/src/gc/stms/cpp/SameThreadMarkAndSweep.cpp
index d27016c..542600f 100644
--- a/kotlin-native/runtime/src/gc/stms/cpp/SameThreadMarkAndSweep.cpp
+++ b/kotlin-native/runtime/src/gc/stms/cpp/SameThreadMarkAndSweep.cpp
@@ -69,8 +69,8 @@
     auto didGC = gc_.PerformFullGC();
 
     if (!didGC) {
-        // If we failed to suspend threads, someone else might be asking to suspend them.
-        mm::SuspendIfRequested();
+        // If we failed to suspend threads, someone else may have requested a safe point action (e.g. suspension).
+        mm::SafePoint();
     }
 }
 
@@ -84,8 +84,8 @@
         case SafepointFlag::kNone:
             RuntimeAssert(false, "Must've been handled by the caller");
             return;
         case SafepointFlag::kNeedsSuspend:
-            mm::SuspendIfRequested();
+            mm::SafePoint();
             return;
         case SafepointFlag::kNeedsGC:
             RuntimeLogDebug({kTagGC}, "Attempt to GC at SafePoint");
diff --git a/kotlin-native/runtime/src/mm/cpp/SafePoint.hpp b/kotlin-native/runtime/src/mm/cpp/SafePoint.hpp
new file mode 100644
index 0000000..d24fc4e
--- /dev/null
+++ b/kotlin-native/runtime/src/mm/cpp/SafePoint.hpp
@@ -0,0 +1,65 @@
+#pragma once
+
+#include <atomic>
+
+#include "KAssert.h"
+#include "Utils.hpp"
+
+namespace kotlin::mm {
+
+// Only references to ThreadData are used below, so a forward declaration suffices. Including "ThreadData.hpp"
+// here would risk an include cycle: "ThreadSuspension.hpp" includes this header, and ThreadData presumably
+// embeds ThreadSuspensionData (see `ThreadData::suspensionData()`).
+class ThreadData;
+
+// An action executed by threads at safe points while it is installed. It receives the ThreadData of the
+// thread executing it.
+using SafePointAction = void (*)(mm::ThreadData&);
+
+// Installs `action` unless another action is already installed. Returns whether `action` was installed.
+bool TrySetSafePointAction(SafePointAction action) noexcept;
+// Uninstalls the currently installed action. Some action must be installed.
+void UnsetSafePointAction() noexcept;
+// Returns whether some safe point action is currently installed.
+bool IsSafePointActionRequested() noexcept;
+
+// Executes the installed safe point action (if any) for the current thread.
+void SafePoint() noexcept;
+// Executes the installed safe point action (if any), passing it `threadData` instead of looking up the current thread's data.
+void SafePoint(ThreadData& threadData) noexcept;
+
+// Installs a safe point action for the lifetime of this object and, if installation succeeded,
+// uninstalls it on destruction.
+class ScopedSafePointAction : private MoveOnly {
+public:
+    explicit ScopedSafePointAction(SafePointAction action) noexcept : actionSet_(TrySetSafePointAction(action)) {}
+    // Same as above, but asserts that installation succeeded, using `reasonToForce` as the message.
+    ScopedSafePointAction(SafePointAction action, const char* reasonToForce) noexcept : ScopedSafePointAction(action) {
+        RuntimeAssert(actionSet(), "%s", reasonToForce);
+    }
+    // The moved-from object gives up responsibility for uninstalling the action, so it is uninstalled exactly once.
+    ScopedSafePointAction(ScopedSafePointAction&& other) noexcept : actionSet_(other.actionSet_) {
+        other.actionSet_ = false;
+    }
+    ScopedSafePointAction& operator=(ScopedSafePointAction&& other) noexcept {
+        if (this != &other) {
+            if (actionSet_) {
+                UnsetSafePointAction();
+            }
+            actionSet_ = other.actionSet_;
+            other.actionSet_ = false;
+        }
+        return *this;
+    }
+    ~ScopedSafePointAction() noexcept {
+        if (actionSet_) {
+            UnsetSafePointAction();
+        }
+    }
+    bool actionSet() const noexcept { return actionSet_; }
+
+private:
+    bool actionSet_;
+};
+
+} // namespace kotlin::mm
diff --git a/kotlin-native/runtime/src/mm/cpp/SafePoints.cpp b/kotlin-native/runtime/src/mm/cpp/SafePoints.cpp
new file mode 100644
index 0000000..64564cf
--- /dev/null
+++ b/kotlin-native/runtime/src/mm/cpp/SafePoints.cpp
@@ -0,0 +1,56 @@
+#include "SafePoint.hpp"
+
+#include <atomic>
+
+#include "KAssert.h"
+#include "ThreadData.hpp"
+#include "ThreadRegistry.hpp"
+
+namespace {
+
+// The currently installed safe point action, or nullptr when no action is requested.
+std::atomic<kotlin::mm::SafePointAction> safePointAction{nullptr};
+
+// Executes the installed action, if any, for `threadData`.
+ALWAYS_INLINE void performSafePointAction(kotlin::mm::ThreadData& threadData) noexcept {
+    // Re-read the action instead of receiving it from the fast path, so that the fast path does not need to keep
+    // the loaded value alive across the call. The action may also have been unset in the meantime.
+    auto action = safePointAction.load(std::memory_order_seq_cst);
+    if (action != nullptr) {
+        action(threadData);
+    }
+}
+
+// Out-of-line slow path of `SafePoint()`: keeps the current thread data lookup off the inlined fast path.
+NO_INLINE void slowPath() noexcept {
+    performSafePointAction(*kotlin::mm::ThreadRegistry::Instance().CurrentThreadData());
+}
+
+} // namespace
+
+bool kotlin::mm::TrySetSafePointAction(kotlin::mm::SafePointAction action) noexcept {
+    // Succeeds only if no action is currently installed.
+    kotlin::mm::SafePointAction expected = nullptr;
+    return safePointAction.compare_exchange_strong(expected, action, std::memory_order_seq_cst);
+}
+
+void kotlin::mm::UnsetSafePointAction() noexcept {
+    RuntimeAssert(kotlin::mm::IsSafePointActionRequested(), "Some safe point action must be set");
+    safePointAction.store(nullptr, std::memory_order_seq_cst);
+}
+
+bool kotlin::mm::IsSafePointActionRequested() noexcept {
+    return safePointAction.load(std::memory_order_relaxed) != nullptr;
+}
+
+// Fast path: one relaxed load and a branch; the action itself is re-read on the slow path.
+ALWAYS_INLINE void kotlin::mm::SafePoint() noexcept {
+    // TODO: compare a nullable action against a separate flag + action call.
+    if (safePointAction.load(std::memory_order_relaxed) != nullptr) /*[[unlikely]]*/ {
+        slowPath();
+    }
+}
+
+void kotlin::mm::SafePoint(kotlin::mm::ThreadData& threadData) noexcept {
+    performSafePointAction(threadData);
+}
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp
index 5d9cffe..b61b308 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp
@@ -15,17 +15,12 @@
 
 namespace {
 
-bool isSuspendedOrNative(kotlin::mm::ThreadData& thread) noexcept {
-    auto& suspensionData = thread.suspensionData();
-    return suspensionData.suspended() || suspensionData.state() == kotlin::ThreadState::kNative;
-}
-
 template<typename F>
 bool allThreads(F predicate) noexcept {
     auto& threadRegistry = kotlin::mm::ThreadRegistry::Instance();
     auto* currentThread = (threadRegistry.IsCurrentThreadRegistered())
             ? threadRegistry.CurrentThreadData()
             : nullptr;
     kotlin::mm::ThreadRegistry::Iterable threads = kotlin::mm::ThreadRegistry::Instance().LockForIter();
     for (auto& thread : threads) {
         // Handle if suspension was initiated by the mutator thread.
@@ -50,7 +45,17 @@
 
 std::atomic<bool> kotlin::mm::internal::gSuspensionRequested = false;
 
-NO_EXTERNAL_CALLS_CHECK void kotlin::mm::ThreadSuspensionData::suspendIfRequestedSlowPath() noexcept {
+kotlin::ThreadState kotlin::mm::ThreadSuspensionData::setState(kotlin::ThreadState newState) noexcept {
+    ThreadState oldState = state_.exchange(newState);
+    if (oldState == ThreadState::kNative && newState == ThreadState::kRunnable) {
+        // Use the already available thread data: the argument-less SafePoint() would look it up in ThreadRegistry.
+        // TODO: document why that lookup must be avoided here.
+        SafePoint(threadData_);
+    }
+    return oldState;
+}
+
+NO_EXTERNAL_CALLS_CHECK void kotlin::mm::ThreadSuspensionData::suspendIfRequested() noexcept {
     if (IsThreadSuspensionRequested()) {
         threadData_.gc().OnSuspendForGC();
         std::unique_lock lock(gSuspensionMutex);
@@ -69,11 +74,12 @@
     RuntimeAssert(gSuspensionRequestedByCurrentThread == false, "Current thread already suspended threads.");
     {
         std::unique_lock lock(gSuspensionMutex);
-        bool actual = false;
-        internal::gSuspensionRequested.compare_exchange_strong(actual, true);
-        if (actual) {
-            return false;
-        }
+        bool safePointSet = mm::TrySetSafePointAction([](mm::ThreadData& threadData) {
+            threadData.suspensionData().suspendIfRequested();
+        });
+        if (!safePointSet) return false;
+        RuntimeAssert(!IsThreadSuspensionRequested(), "Suspension must not be requested without altering safe point action");
+        internal::gSuspensionRequested = true;
     }
     gSuspensionRequestedByCurrentThread = true;
 
@@ -82,21 +88,11 @@
 
 void kotlin::mm::WaitForThreadsSuspension() noexcept {
     // Spin waiting for threads to suspend. Ignore Native threads.
-    while(!allThreads(isSuspendedOrNative)) {
+    while(!allThreads([] (mm::ThreadData& thread) { return thread.suspensionData().suspendedOrNative(); })) {
         yield();
     }
 }
 
-NO_INLINE void kotlin::mm::SuspendIfRequestedSlowPath() noexcept {
-    mm::ThreadRegistry::Instance().CurrentThreadData()->suspensionData().suspendIfRequestedSlowPath();
-}
-
-ALWAYS_INLINE void kotlin::mm::SuspendIfRequested() noexcept {
-    if (IsThreadSuspensionRequested()) {
-        SuspendIfRequestedSlowPath();
-    }
-}
-
 void kotlin::mm::ResumeThreads() noexcept {
     // From the std::condition_variable docs:
     // Even if the shared variable is atomic, it must be modified under
@@ -105,6 +101,7 @@
     {
         std::unique_lock lock(gSuspensionMutex);
         internal::gSuspensionRequested = false;
+        mm::UnsetSafePointAction();
     }
     gSuspensionRequestedByCurrentThread = false;
     gSuspensionCondVar.notify_all();
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp
index fe2c8af..cd1c74a 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp
@@ -9,6 +9,7 @@
 #include <atomic>
 
 #include "Memory.h"
+#include "SafePoint.hpp"
 
 namespace kotlin {
 namespace mm {
@@ -34,35 +35,21 @@
 
     ThreadState state() noexcept { return state_; }
 
-    ThreadState setState(ThreadState newState) noexcept {
-        ThreadState oldState = state_.exchange(newState);
-        if (oldState == ThreadState::kNative && newState == ThreadState::kRunnable) {
-            suspendIfRequested();
-        }
-        return oldState;
-    }
+    ThreadState setState(ThreadState newState) noexcept;
 
     bool suspended() noexcept { return suspended_; }
+    bool suspendedOrNative() noexcept { return suspended() || state() == kotlin::ThreadState::kNative; }
 
-    NO_EXTERNAL_CALLS_CHECK void suspendIfRequested() noexcept {
-        if (IsThreadSuspensionRequested()) {
-            suspendIfRequestedSlowPath();
-        }
-    }
+    void suspendIfRequested() noexcept;
 
 private:
-    friend void SuspendIfRequestedSlowPath() noexcept;
-
     std::atomic<ThreadState> state_;
     mm::ThreadData& threadData_;
     std::atomic<bool> suspended_;
-    void suspendIfRequestedSlowPath() noexcept;
 };
 
 bool RequestThreadsSuspension() noexcept;
 void WaitForThreadsSuspension() noexcept;
-void SuspendIfRequestedSlowPath() noexcept;
-void SuspendIfRequested() noexcept;
 
 /**
  * Suspends all threads registered in ThreadRegistry except threads that are in the Native state.
@@ -84,6 +71,40 @@
  */
 void ResumeThreads() noexcept;
 
+// Requests suspension of all threads (see `RequestThreadsSuspension`) for the lifetime of this object
+// and, if the request succeeded, resumes the threads (see `ResumeThreads`) on destruction.
+class ScopedSTWRequest : private MoveOnly {
+public:
+    ScopedSTWRequest() noexcept : suspensionRequested_(RequestThreadsSuspension()) {}
+    // Same as the default constructor, but asserts that the request succeeded, using `reasonToForce` as the message.
+    explicit ScopedSTWRequest(const char* reasonToForce) noexcept : ScopedSTWRequest() {
+        RuntimeAssert(suspensionRequested(), "%s", reasonToForce);
+    }
+    // The moved-from object gives up responsibility for resuming threads, so ResumeThreads is called exactly once.
+    ScopedSTWRequest(ScopedSTWRequest&& other) noexcept : suspensionRequested_(other.suspensionRequested_) {
+        other.suspensionRequested_ = false;
+    }
+    ScopedSTWRequest& operator=(ScopedSTWRequest&& other) noexcept {
+        if (this != &other) {
+            if (suspensionRequested_) {
+                ResumeThreads();
+            }
+            suspensionRequested_ = other.suspensionRequested_;
+            other.suspensionRequested_ = false;
+        }
+        return *this;
+    }
+    ~ScopedSTWRequest() noexcept {
+        if (suspensionRequested_) {
+            ResumeThreads();
+        }
+    }
+    bool suspensionRequested() const noexcept { return suspensionRequested_; }
+
+private:
+    bool suspensionRequested_;
+};
+
 } // namespace mm
 } // namespace kotlin
 
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp b/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp
index 800d2e8..064f518 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp
@@ -228,7 +228,7 @@
         EXPECT_EQ(GetThreadState(), ThreadState::kRunnable);
         // Give other threads a chance to call CallInitGlobalPossiblyLock.
         std::this_thread::yield();
-        mm::SuspendIfRequested();
+        mm::SafePoint(); // Performs the pending safe point action (presumably the suspension request here), if any.
     });
 
     for (size_t i = 0; i < kThreadCount; i++) {
@@ -241,7 +241,7 @@
 
             CallInitGlobalPossiblyLock(&lock, initializationFunction);
             // Try to suspend to handle a case when this thread doesn't call the initialization function.
-            mm::SuspendIfRequested();
+            mm::SafePoint();
         });
     }
     waitUntilThreadsAreReady();