[K/N][Runtime] Support suspending mutators in the new MM
diff --git a/kotlin-native/runtime/build.gradle.kts b/kotlin-native/runtime/build.gradle.kts
index 591e469..2334fa9 100644
--- a/kotlin-native/runtime/build.gradle.kts
+++ b/kotlin-native/runtime/build.gradle.kts
@@ -272,6 +272,12 @@
})
}
+// Aggregate task: depends on every CompileToBitcode task whose outputGroup is "main"
+// and whose target is the host (hostName).
+val hostAssemble by tasks.registering {
+ dependsOn(tasks.withType(CompileToBitcode::class).matching {
+ it.outputGroup == "main" && it.target == hostName
+ })
+}
+
val clean by tasks.registering {
doFirst {
delete(buildDir)
diff --git a/kotlin-native/runtime/src/legacymm/cpp/Memory.cpp b/kotlin-native/runtime/src/legacymm/cpp/Memory.cpp
index faa0166..c28e24e 100644
--- a/kotlin-native/runtime/src/legacymm/cpp/Memory.cpp
+++ b/kotlin-native/runtime/src/legacymm/cpp/Memory.cpp
@@ -3733,6 +3733,10 @@
// no-op, used by the new MM only.
}
+// Thread-state assertions are implemented by the new MM only; in the legacy MM this overload does nothing.
+ALWAYS_INLINE void kotlin::AssertThreadState(MemoryState* /* thread */, std::initializer_list<ThreadState> /* expected */) noexcept {}
+
MemoryState* kotlin::mm::GetMemoryState() {
return ::memoryState;
}
diff --git a/kotlin-native/runtime/src/main/cpp/Memory.h b/kotlin-native/runtime/src/main/cpp/Memory.h
index b22c6f3..586c79e 100644
--- a/kotlin-native/runtime/src/main/cpp/Memory.h
+++ b/kotlin-native/runtime/src/main/cpp/Memory.h
@@ -17,6 +17,7 @@
#ifndef RUNTIME_MEMORY_H
#define RUNTIME_MEMORY_H
+#include <ostream>
#include <utility>
#include "KAssert.h"
@@ -410,12 +411,17 @@
// Asserts that the given thread is in the given state.
ALWAYS_INLINE void AssertThreadState(MemoryState* thread, ThreadState expected) noexcept;
+// Asserts that the given thread is in one of the given states (a no-op in the legacy MM).
+ALWAYS_INLINE void AssertThreadState(MemoryState* thread, std::initializer_list<ThreadState> expected) noexcept;
// Asserts that the current thread is in the the given state.
ALWAYS_INLINE inline void AssertThreadState(ThreadState expected) noexcept {
AssertThreadState(mm::GetMemoryState(), expected);
}
+// Asserts that the current thread is in one of the given states.
+ALWAYS_INLINE inline void AssertThreadState(std::initializer_list<ThreadState> expected) noexcept {
+    MemoryState* currentThread = mm::GetMemoryState();
+    AssertThreadState(currentThread, expected);
+}
+
// Scopely sets the given thread state for the given thread.
class ThreadStateGuard final : private Pinned {
public:
diff --git a/kotlin-native/runtime/src/main/cpp/Utils.hpp b/kotlin-native/runtime/src/main/cpp/Utils.hpp
index 1f10f53..a55bd75 100644
--- a/kotlin-native/runtime/src/main/cpp/Utils.hpp
+++ b/kotlin-native/runtime/src/main/cpp/Utils.hpp
@@ -6,6 +6,8 @@
#ifndef RUNTIME_UTILS_H
#define RUNTIME_UTILS_H
+#include <type_traits>
+
namespace kotlin {
// A helper for implementing classes with disabled copy constructor and copy assignment.
@@ -52,6 +54,28 @@
~Pinned() = default;
};
+// A helper that scopedly assigns a value to a variable. The variable is
+// restored to its original value when the AutoReset instance is destroyed.
+// Note that an AutoReset instance must have a shorter lifetime than
+// the variable it works with to avoid invalid memory access.
+//
+// The original value is stored as a T2, so a T1 must be convertible to T2
+// (this works for std::atomic<T> with T2 = T as well).
+template<typename T1, typename T2>
+class AutoReset final : private Pinned {
+    // Both assignments go through `*variable_`, which is an lvalue of type T1, so
+    // assignability must be checked against T1&. `std::is_assignable<T1, T2>` checks
+    // assignment to an rvalue T1 and wrongly rejects e.g. AutoReset<int, int>.
+    static_assert(std::is_assignable<T1&, T2>::value, "AutoReset: a value of type T2 must be assignable to T1");
+
+public:
+    AutoReset(T1* variable, T2 value) : variable_(variable), oldValue_(*variable) {
+        *variable_ = value;
+    }
+
+    ~AutoReset() {
+        *variable_ = oldValue_;
+    }
+
+private:
+    T1* variable_;
+    T2 oldValue_;
+};
+
} // namespace kotlin
#endif // RUNTIME_UTILS_H
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadData.hpp b/kotlin-native/runtime/src/mm/cpp/ThreadData.hpp
index c994561..b6407c4 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadData.hpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadData.hpp
@@ -18,6 +18,7 @@
#include "ThreadLocalStorage.hpp"
#include "Types.h"
#include "Utils.hpp"
+#include "ThreadSuspension.hpp"
struct ObjHeader;
@@ -32,9 +33,9 @@
threadId_(threadId),
globalsThreadQueue_(GlobalsRegistry::Instance()),
stableRefThreadQueue_(StableRefRegistry::Instance()),
- state_(ThreadState::kRunnable),
gc_(GlobalData::Instance().gc()),
- objectFactoryThreadQueue_(GlobalData::Instance().objectFactory(), gc_) {}
+ objectFactoryThreadQueue_(GlobalData::Instance().objectFactory(), gc_),
+ suspensionData_(ThreadState::kRunnable) {}
~ThreadData() = default;
@@ -46,9 +47,9 @@
StableRefRegistry::ThreadQueue& stableRefThreadQueue() noexcept { return stableRefThreadQueue_; }
- ThreadState state() noexcept { return state_; }
+    // The thread state is owned by the suspension data.
+    ThreadState state() noexcept {
+        return suspensionData_.state();
+    }
- ThreadState setState(ThreadState state) noexcept { return state_.exchange(state); }
+    // Delegates to the suspension data and returns the previous state. A NATIVE -> RUNNABLE
+    // switch may block there while a thread suspension is requested.
+    ThreadState setState(ThreadState state) noexcept {
+        return suspensionData_.setState(state);
+    }
ObjectFactory<gc::GC>::ThreadQueue& objectFactoryThreadQueue() noexcept { return objectFactoryThreadQueue_; }
@@ -58,6 +59,8 @@
gc::GC::ThreadData& gc() noexcept { return gc_; }
+    // Per-thread suspension bookkeeping; it also holds the thread state (see state()/setState()).
+    ThreadSuspensionData& suspensionData() noexcept { return suspensionData_; }
+
void Publish() noexcept {
// TODO: These use separate locks, which is inefficient.
globalsThreadQueue_.Publish();
@@ -76,11 +79,11 @@
GlobalsRegistry::ThreadQueue globalsThreadQueue_;
ThreadLocalStorage tls_;
StableRefRegistry::ThreadQueue stableRefThreadQueue_;
- std::atomic<ThreadState> state_;
ShadowStack shadowStack_;
gc::GC::ThreadData gc_;
ObjectFactory<gc::GC>::ThreadQueue objectFactoryThreadQueue_;
KStdVector<std::pair<ObjHeader**, ObjHeader*>> initializingSingletons_;
+ ThreadSuspensionData suspensionData_;
};
} // namespace mm
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadState.cpp b/kotlin-native/runtime/src/mm/cpp/ThreadState.cpp
index 8486ca4..b4c0064 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadState.cpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadState.cpp
@@ -7,7 +7,7 @@
#include "ThreadData.hpp"
#include "ThreadState.hpp"
-const char* kotlin::internal::stateToString(ThreadState state) noexcept {
+const char* kotlin::ThreadStateName(ThreadState state) noexcept {
switch (state) {
case ThreadState::kRunnable:
return "RUNNABLE";
@@ -16,6 +16,18 @@
}
}
+std::string kotlin::internal::statesToString(std::initializer_list<ThreadState> states) noexcept {
+    // Renders e.g. "{ RUNNABLE, NATIVE }"; an empty list renders as "{  }".
+    std::string result = "{ ";
+    const char* separator = "";
+    for (ThreadState state : states) {
+        result += separator;
+        result += ThreadStateName(state);
+        separator = ", ";
+    }
+    result += " }";
+    return result;
+}
+
ALWAYS_INLINE ThreadState kotlin::SwitchThreadState(MemoryState* thread, ThreadState newState, bool reentrant) noexcept {
return SwitchThreadState(thread->GetThreadData(), newState, reentrant);
}
@@ -24,6 +36,10 @@
AssertThreadState(thread->GetThreadData(), expected);
}
+ALWAYS_INLINE void kotlin::AssertThreadState(MemoryState* thread, std::initializer_list<ThreadState> expected) noexcept {
+    // Resolve the per-thread data and defer to the ThreadData-based overload.
+    auto* threadData = thread->GetThreadData();
+    AssertThreadState(threadData, expected);
+}
+
ThreadState kotlin::GetThreadState(MemoryState* thread) noexcept {
return thread->GetThreadData()->state();
}
\ No newline at end of file
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadState.hpp b/kotlin-native/runtime/src/mm/cpp/ThreadState.hpp
index 6b3c8ee..f519660 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadState.hpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadState.hpp
@@ -6,10 +6,13 @@
#ifndef RUNTIME_MM_THREAD_STATE_H
#define RUNTIME_MM_THREAD_STATE_H
+#include <algorithm>
+#include <initializer_list>
+#include <ostream>
+#include <string>
+
#include <Common.h>
#include <Utils.hpp>
#include "ThreadData.hpp"
+#include "ThreadSuspension.hpp"
namespace kotlin {
@@ -19,17 +22,23 @@
return oldState != newState || reentrant;
}
-const char* stateToString(ThreadState state) noexcept;
+// Formats `states` as a brace-enclosed, comma-separated list of state names, for diagnostics.
+std::string statesToString(std::initializer_list<ThreadState> states) noexcept;
} // namespace internal
+// Returns the printable name of `state` (e.g. "RUNNABLE").
+const char* ThreadStateName(ThreadState state) noexcept;
+
+// Streams a ThreadState using its printable name.
+inline std::ostream& operator<<(std::ostream& stream, ThreadState state) {
+    stream << ThreadStateName(state);
+    return stream;
+}
+
// Switches the state of the given thread to `newState` and returns the previous thread state.
ALWAYS_INLINE inline ThreadState SwitchThreadState(mm::ThreadData* threadData, ThreadState newState, bool reentrant = false) noexcept {
auto oldState = threadData->setState(newState);
// TODO(perf): Mesaure the impact of this assert in debug and opt modes.
RuntimeAssert(internal::isStateSwitchAllowed(oldState, newState, reentrant),
"Illegal thread state switch. Old state: %s. New state: %s.",
- internal::stateToString(oldState), internal::stateToString(newState));
+ ThreadStateName(oldState), ThreadStateName(newState));
return oldState;
}
@@ -38,7 +47,14 @@
auto actual = threadData->state();
RuntimeAssert(actual == expected,
"Unexpected thread state. Expected: %s. Actual: %s.",
- internal::stateToString(expected), internal::stateToString(actual));
+ ThreadStateName(expected), ThreadStateName(actual));
+}
+
+// Asserts that the thread described by `threadData` is in one of the `expected` states.
+ALWAYS_INLINE inline void AssertThreadState(mm::ThreadData* threadData, std::initializer_list<ThreadState> expected) noexcept {
+    auto actual = threadData->state();
+    RuntimeAssert(std::find(expected.begin(), expected.end(), actual) != expected.end(),
+            "Unexpected thread state. Expected one of: %s. Actual: %s",
+            internal::statesToString(expected).c_str(), ThreadStateName(actual));
}
} // namespace kotlin
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadStateTest.cpp b/kotlin-native/runtime/src/mm/cpp/ThreadStateTest.cpp
index c556820..98612df 100644
--- a/kotlin-native/runtime/src/mm/cpp/ThreadStateTest.cpp
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadStateTest.cpp
@@ -41,6 +41,9 @@
//static
testing::MockFunction<int32_t(int32_t)>* ThreadStateTest::globalSomeFunctionMock = nullptr;
+// Counterpart of EXPECT_DEATH: runs `statement` in a forked child process and expects
+// the child to reach the explicit exit(0), i.e. `statement` must not crash or abort.
+#define EXPECT_NO_DEATH(statement) \
+    do { \
+        EXPECT_EXIT( \
+                { \
+                    statement; \
+                    exit(0); \
+                }, \
+                testing::ExitedWithCode(0), testing::_); \
+    } while (false)
+
} // namespace
TEST_F(ThreadStateTest, StateSwitchWithThreadData) {
@@ -148,13 +151,11 @@
});
}
-TEST(ThreadStateDeathTest, IncorrectStateSwitch) {
+TEST(ThreadStateDeathTest, IncorrectStateSwitchWithDifferentFunctions) {
RunInNewThread([](MemoryState* memoryState) {
auto* threadData = memoryState->GetThreadData();
EXPECT_DEATH(SwitchThreadState(memoryState, ThreadState::kRunnable),
"runtime assert: Illegal thread state switch. Old state: RUNNABLE. New state: RUNNABLE");
- EXPECT_DEATH(SwitchThreadState(threadData, ThreadState::kRunnable),
- "runtime assert: Illegal thread state switch. Old state: RUNNABLE. New state: RUNNABLE");
EXPECT_DEATH(Kotlin_mm_switchThreadStateRunnable(),
"runtime assert: Illegal thread state switch. Old state: RUNNABLE. New state: RUNNABLE");
@@ -165,6 +166,24 @@
});
}
+// Checks which switches SwitchThreadState accepts on a standalone ThreadData:
+// RUNNABLE <-> NATIVE passes, switching to the state the thread is already in asserts.
+TEST(ThreadStateDeathTest, StateSwitchCorrectness) {
+ mm::ThreadData threadData(pthread_self());
+
+ // Allowed state switches: runnable <-> native
+ // setState() does not check the switch, so it is used to put the thread into a known state.
+ threadData.setState(ThreadState::kRunnable);
+ ASSERT_EQ(threadData.state(), ThreadState::kRunnable);
+ EXPECT_DEATH(SwitchThreadState(&threadData, ThreadState::kRunnable),
+ "runtime assert: Illegal thread state switch. Old state: RUNNABLE. New state: RUNNABLE");
+ // Each EXPECT_NO_DEATH is executed in a fork process, so the global state of the test is not affected.
+ EXPECT_NO_DEATH(SwitchThreadState(&threadData, ThreadState::kNative));
+
+ threadData.setState(ThreadState::kNative);
+ ASSERT_EQ(threadData.state(), ThreadState::kNative);
+ EXPECT_NO_DEATH(SwitchThreadState(&threadData, ThreadState::kRunnable));
+ EXPECT_DEATH(SwitchThreadState(&threadData, ThreadState::kNative),
+ "runtime assert: Illegal thread state switch. Old state: NATIVE. New state: NATIVE");
+}
+
TEST(ThreadStateDeathTest, ReentrantStateSwitch) {
RunInNewThread([](MemoryState* memoryState) {
auto* threadData = memoryState->GetThreadData();
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp
new file mode 100644
index 0000000..424c61c
--- /dev/null
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.cpp
@@ -0,0 +1,95 @@
+/*
+ * Copyright 2010-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
+ * that can be found in the LICENSE file.
+ */
+
+#include "ThreadData.hpp"
+#include "ThreadSuspension.hpp"
+
+#include <condition_variable>
+#include <thread>
+#include <mutex>
+
+namespace {
+
+// A thread does not need to be waited for if it is parked in suspendIfRequested()
+// or is currently in the native state.
+bool isSuspendedOrNative(kotlin::mm::ThreadData& thread) noexcept {
+    auto& suspensionData = thread.suspensionData();
+    return suspensionData.suspended() || suspensionData.state() == kotlin::ThreadState::kNative;
+}
+
+// Returns true if `predicate` holds for every registered thread except the calling one.
+// The calling thread is skipped because it may be the mutator that initiated the suspension.
+template<typename F>
+bool allThreads(F predicate) noexcept {
+    auto& registry = kotlin::mm::ThreadRegistry::Instance();
+    auto* callingThread = registry.CurrentThreadData();
+    kotlin::mm::ThreadRegistry::Iterable registeredThreads = registry.Iter();
+    for (auto& thread : registeredThreads) {
+        if (&thread == callingThread) continue;
+        if (!predicate(thread)) return false;
+    }
+    return true;
+}
+
+void yield() noexcept {
+    std::this_thread::yield();
+}
+
+// True while a suspension request is active; it is only written under gSuspensionMutex.
+std::atomic<bool> gSuspensionRequested = false;
+// True on the thread whose SuspendThreads() call succeeded, until it calls ResumeThreads().
+THREAD_LOCAL_VARIABLE bool gSuspensionRequestedByCurrentThread = false;
+std::mutex gSuspensionMutex;
+// Notified by ResumeThreads() to wake the threads parked in suspendIfRequested().
+std::condition_variable gSuspendsionCondVar;
+
+} // namespace
+
+bool kotlin::mm::ThreadSuspensionData::suspendIfRequested() noexcept {
+    // The thread that requested the suspension must keep running: it is the one expected to
+    // call ResumeThreads(), so parking it here would wait forever (allThreads() skips it too).
+    if (gSuspensionRequestedByCurrentThread) {
+        return false;
+    }
+    // Fast path: do not touch the mutex when no suspension is requested.
+    if (!IsThreadSuspensionRequested()) {
+        return false;
+    }
+    std::unique_lock lock(gSuspensionMutex);
+    // Re-check under the mutex: ResumeThreads() may have cleared the request in between.
+    if (!IsThreadSuspensionRequested()) {
+        return false;
+    }
+    // Expose `suspended_ == true` to the suspender while parked; it is restored (still
+    // under the mutex) when scopedAssign is destroyed before `lock`.
+    AutoReset scopedAssign(&suspended_, true);
+    gSuspendsionCondVar.wait(lock, []() { return !IsThreadSuspensionRequested(); });
+    return true;
+}
+
+bool kotlin::mm::IsThreadSuspensionRequested() noexcept {
+    // Sequentially consistent load (the default order, spelled out).
+    // TODO: Consider using a more relaxed memory order.
+    return gSuspensionRequested.load(std::memory_order_seq_cst);
+}
+
+bool kotlin::mm::SuspendThreads() noexcept {
+    RuntimeAssert(gSuspensionRequestedByCurrentThread == false, "Current thread already suspended threads.");
+    {
+        std::unique_lock lock(gSuspensionMutex);
+        bool expected = false;
+        // At most one request can be active: give up if another thread already made one.
+        if (!gSuspensionRequested.compare_exchange_strong(expected, true)) {
+            return false;
+        }
+    }
+    gSuspensionRequestedByCurrentThread = true;
+
+    // Spin waiting for the other threads to suspend. Threads in the native state are not
+    // waited for; they get suspended when they switch back to the runnable state.
+    while (!allThreads(isSuspendedOrNative)) {
+        yield();
+    }
+
+    return true;
+}
+
+void kotlin::mm::ResumeThreads() noexcept {
+    // The flag is cleared under the mutex even though it is atomic: per the
+    // std::condition_variable docs, the shared variable must be modified under the mutex
+    // in order to correctly publish the modification to the waiting threads.
+    // https://en.cppreference.com/w/cpp/thread/condition_variable
+    {
+        std::lock_guard<std::mutex> guard(gSuspensionMutex);
+        gSuspensionRequested.store(false);
+    }
+    gSuspensionRequestedByCurrentThread = false;
+    gSuspendsionCondVar.notify_all();
+}
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp
new file mode 100644
index 0000000..e57584f
--- /dev/null
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadSuspension.hpp
@@ -0,0 +1,60 @@
+/*
+ * Copyright 2010-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
+ * that can be found in the LICENSE file.
+ */
+
+#ifndef RUNTIME_MM_THREAD_SUSPENSION_UTILS_H
+#define RUNTIME_MM_THREAD_SUSPENSION_UTILS_H
+
+#include <atomic>
+
+#include "Memory.h"
+
+namespace kotlin {
+namespace mm {
+
+// Per-thread data used to suspend mutators: the thread state and a flag telling
+// whether the thread is currently parked in suspendIfRequested().
+class ThreadSuspensionData : private Pinned {
+public:
+    explicit ThreadSuspensionData(ThreadState initialState) noexcept : state_(initialState), suspended_(false) {}
+
+    ~ThreadSuspensionData() = default;
+
+    ThreadState state() const noexcept { return state_; }
+
+    // Atomically replaces the thread state and returns the previous one. On a NATIVE -> RUNNABLE
+    // switch it calls suspendIfRequested(), so it may block until the threads are resumed.
+    ThreadState setState(ThreadState newState) noexcept {
+        ThreadState oldState = state_.exchange(newState);
+        if (oldState == ThreadState::kNative && newState == ThreadState::kRunnable) {
+            suspendIfRequested();
+        }
+        return oldState;
+    }
+
+    // True while the thread is blocked in suspendIfRequested().
+    bool suspended() const noexcept { return suspended_; }
+
+    // If a suspension is requested, may block the calling thread until ResumeThreads() is called.
+    // Returns true if the calling thread was actually suspended.
+    bool suspendIfRequested() noexcept;
+
+private:
+    std::atomic<ThreadState> state_;
+    std::atomic<bool> suspended_;
+};
+
+// Returns true while a successful SuspendThreads() request is active (until ResumeThreads()).
+bool IsThreadSuspensionRequested() noexcept;
+
+/**
+ * Suspends all threads registered in ThreadRegistry except threads that are in the Native state.
+ * The calling thread itself is not suspended.
+ * Blocks until all such threads are suspended. Threads that are in the Native state at the moment
+ * of this call will be suspended on exit from the Native state.
+ * Returns false (without doing anything) if some other thread has already requested suspension.
+ */
+bool SuspendThreads() noexcept;
+
+/**
+ * Resumes all threads registered in ThreadRegistry that were suspended by the SuspendThreads call.
+ * Does not wait until all such threads are actually resumed.
+ */
+void ResumeThreads() noexcept;
+
+} // namespace mm
+} // namespace kotlin
+
+#endif // RUNTIME_MM_THREAD_SUSPENSION_UTILS_H
diff --git a/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp b/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp
new file mode 100644
index 0000000..4d53f66
--- /dev/null
+++ b/kotlin-native/runtime/src/mm/cpp/ThreadSuspensionTest.cpp
@@ -0,0 +1,205 @@
+/*
+ * Copyright 2010-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license
+ * that can be found in the LICENSE file.
+ */
+
+#include "MemoryPrivate.hpp"
+#include "ThreadSuspension.hpp"
+#include "ThreadState.hpp"
+
+#include <gtest/gtest.h>
+#include <gmock/gmock.h>
+
+#include <thread>
+#include <TestSupport.hpp>
+
+#include <iostream>
+
+using namespace kotlin;
+
+namespace {
+
+#ifdef KONAN_WINDOWS
+constexpr size_t kDefaultIterations = 1000;
+constexpr size_t kDefaultReportingStep = 100;
+#else
+constexpr size_t kDefaultIterations = 10000;
+constexpr size_t kDefaultReportingStep = 1000;
+#endif // #ifdef KONAN_WINDOWS
+
+// Returns pointers to all threads currently registered in the ThreadRegistry.
+KStdVector<mm::ThreadData*> collectThreadData() {
+    KStdVector<mm::ThreadData*> threadDataList;
+    auto registeredThreads = mm::ThreadRegistry::Instance().Iter();
+    for (auto& threadData : registeredThreads) {
+        threadDataList.push_back(&threadData);
+    }
+    return threadDataList;
+}
+
+// Applies `extractFunction` to every registered thread and collects the results in order.
+template<typename T, typename F>
+KStdVector<T> collectFromThreadData(F extractFunction) {
+    KStdVector<T> values;
+    for (mm::ThreadData* threadData : collectThreadData()) {
+        values.push_back(extractFunction(threadData));
+    }
+    return values;
+}
+
+// The `suspended` flag of every registered thread.
+KStdVector<bool> collectSuspended() {
+    return collectFromThreadData<bool>([](mm::ThreadData* threadData) {
+        return threadData->suspensionData().suspended();
+    });
+}
+
+// Prints a progress line every kDefaultReportingStep iterations.
+void reportProgress(size_t currentIteration, size_t totalIterations) {
+    if (currentIteration % kDefaultReportingStep != 0) {
+        return;
+    }
+    std::cout << "Iteration: " << currentIteration << " of " << totalIterations << std::endl;
+}
+
+} // namespace
+
+// Fixture owning the worker threads of a test. Workers park in waitUntilCanStart() until the
+// main thread sets `canStart`; waitUntilThreadsAreReady() resets `canStart` and spins until
+// every worker has set its `ready` flag.
+class ThreadSuspensionTest : public ::testing::Test {
+public:
+ // Releases workers still waiting on the barrier, asks them to stop, and joins them.
+ ~ThreadSuspensionTest() {
+ canStart = true;
+ shouldStop = true;
+ for (auto& thread : threads) {
+ if (thread.joinable()) thread.join();
+ }
+ }
+
+ static constexpr size_t kThreadCount = kDefaultThreadCount;
+ static constexpr size_t kIterations = kDefaultIterations;
+
+ KStdVector<std::thread> threads;
+ // ready[i] is set by worker i when it reaches the barrier.
+ std::array<std::atomic<bool>, kThreadCount> ready{false};
+ std::atomic<bool> canStart{false};
+ std::atomic<bool> shouldStop{false};
+
+ // Called by worker `threadNumber`: reports readiness and spins until the main thread allows it to run.
+ void waitUntilCanStart(size_t threadNumber) {
+ ready[threadNumber] = true;
+ while(!canStart) {
+ std::this_thread::yield();
+ }
+ ready[threadNumber] = false;
+ }
+
+ // Closes the barrier and spins until all kThreadCount workers have reported ready.
+ void waitUntilThreadsAreReady() {
+ canStart = false;
+ while (!std::all_of(ready.begin(), ready.end(), [](bool it) { return it; })) {
+ std::this_thread::yield();
+ }
+ }
+};
+
+// Workers repeatedly call suspendIfRequested(); the main thread (not registered as a mutator)
+// suspends and resumes them and checks every worker's `suspended` flag in both phases.
+TEST_F(ThreadSuspensionTest, SimpleStartStop) {
+ ASSERT_THAT(collectThreadData(), testing::IsEmpty());
+ for (size_t i = 0; i < kThreadCount; i++) {
+ threads.emplace_back([this, i]() {
+ ScopedMemoryInit init;
+ auto& suspensionData = init.memoryState()->GetThreadData()->suspensionData();
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), false);
+
+ while(!shouldStop) {
+ waitUntilCanStart(i);
+
+ // `suspended` is only true while blocked inside suspendIfRequested().
+ EXPECT_FALSE(suspensionData.suspended());
+ suspensionData.suspendIfRequested();
+ EXPECT_FALSE(suspensionData.suspended());
+ }
+ });
+ }
+ waitUntilThreadsAreReady();
+
+ for (size_t i = 0; i < kIterations; i++) {
+ reportProgress(i, kIterations);
+ canStart = true;
+
+ // Workers are RUNNABLE, so SuspendThreads() returns only once all of them are parked.
+ mm::SuspendThreads();
+ auto suspended = collectSuspended();
+ EXPECT_THAT(suspended, testing::Each(true));
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), true);
+
+ mm::ResumeThreads();
+
+ // Wait for threads to run and sync for the next iteration
+ waitUntilThreadsAreReady();
+
+ suspended = collectSuspended();
+ EXPECT_THAT(suspended, testing::Each(false));
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), false);
+ }
+}
+
+
+// Workers keep switching RUNNABLE -> NATIVE -> RUNNABLE. SuspendThreads() does not wait for
+// NATIVE threads; they get suspended inside the NATIVE -> RUNNABLE switch instead.
+TEST_F(ThreadSuspensionTest, SwitchStateToNative) {
+ ASSERT_THAT(collectThreadData(), testing::IsEmpty());
+
+ for (size_t i = 0; i < kThreadCount; i++) {
+ threads.emplace_back([this, i]() {
+ ScopedMemoryInit init;
+ auto* threadData = init.memoryState()->GetThreadData();
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), false);
+
+ while(!shouldStop) {
+ waitUntilCanStart(i);
+
+ EXPECT_EQ(threadData->state(), ThreadState::kRunnable);
+ SwitchThreadState(threadData, ThreadState::kNative);
+ EXPECT_EQ(threadData->state(), ThreadState::kNative);
+ // May block here while a suspension is requested.
+ SwitchThreadState(threadData, ThreadState::kRunnable);
+ EXPECT_EQ(threadData->state(), ThreadState::kRunnable);
+ }
+ });
+ }
+ waitUntilThreadsAreReady();
+
+ for (size_t i = 0; i < kIterations; i++) {
+ reportProgress(i, kIterations);
+ canStart = true;
+
+ mm::SuspendThreads();
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), true);
+
+ mm::ResumeThreads();
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), false);
+
+ // Sync for the next iteration.
+ waitUntilThreadsAreReady();
+ }
+}
+
+// All workers call SuspendThreads() at the same time: exactly one call must succeed,
+// and every other worker must end up parked until the winner resumes the threads.
+TEST_F(ThreadSuspensionTest, ConcurrentSuspend) {
+ ASSERT_THAT(collectThreadData(), testing::IsEmpty());
+ std::atomic<size_t> successCount = 0;
+
+ for (size_t i = 0; i < kThreadCount; i++) {
+ threads.emplace_back([this, i, &successCount]() {
+ ScopedMemoryInit init;
+ auto* currentThreadData = init.memoryState()->GetThreadData();
+ EXPECT_EQ(mm::IsThreadSuspensionRequested(), false);
+
+ // Sync with other threads.
+ ready[i] = true;
+ waitUntilThreadsAreReady();
+
+ bool success = mm::SuspendThreads();
+ if (success) {
+ successCount++;
+ auto allThreadData = collectThreadData();
+ auto isCurrentOrSuspended = [currentThreadData](mm::ThreadData* data) {
+ return data == currentThreadData || data->suspensionData().suspended();
+ };
+ EXPECT_THAT(allThreadData, testing::Each(testing::Truly(isCurrentOrSuspended)));
+ EXPECT_FALSE(currentThreadData->suspensionData().suspended());
+ mm::ResumeThreads();
+ } else {
+ // This thread is still RUNNABLE and not yet parked, so the winner cannot have
+ // resumed yet: the request must still be active here.
+ EXPECT_TRUE(mm::IsThreadSuspensionRequested());
+ currentThreadData->suspensionData().suspendIfRequested();
+ }
+ });
+ }
+ for (auto& thread : threads) {
+ thread.join();
+ }
+ EXPECT_EQ(successCount, 1u);
+}