Parallelize stdout and clean up PersistentWorker (#501)
* Parallelize stdout and refactor PersistentWorker
* Serialize writing to stdout
* Dont use flow
diff --git a/src/main/kotlin/io/bazel/worker/PersistentWorker.kt b/src/main/kotlin/io/bazel/worker/PersistentWorker.kt
index 2b99b8b..10fdad8 100644
--- a/src/main/kotlin/io/bazel/worker/PersistentWorker.kt
+++ b/src/main/kotlin/io/bazel/worker/PersistentWorker.kt
@@ -29,7 +29,17 @@
import kotlinx.coroutines.flow.buffer
import kotlinx.coroutines.flow.collect
import kotlinx.coroutines.flow.map
+import kotlinx.coroutines.flow.flowOn
+import kotlinx.coroutines.flow.Flow
+import kotlinx.coroutines.Job
+import kotlinx.coroutines.channels.Channel
+import kotlinx.coroutines.channels.Channel.Factory.UNLIMITED
+import kotlinx.coroutines.flow.consumeAsFlow
+import kotlinx.coroutines.flow.onCompletion
+import kotlinx.coroutines.launch
import kotlinx.coroutines.runBlocking
+import kotlinx.coroutines.withContext
+import java.io.PrintStream
import java.nio.charset.StandardCharsets.UTF_8
import java.util.concurrent.Executors
import kotlin.coroutines.CoroutineContext
@@ -50,81 +60,63 @@
constructor() : this(Dispatchers.IO, IO.Companion::capture)
- /**
- * ThreadAwareDispatchers provides an ability to separate thread blocking operations from coroutines..
- *
- * Coroutines interleave actions over a pool of threads. When an action blocks it stands a chance
- * of producing a deadlock. We sidestep this by providing a separate dispatcher to contain
- * blocking operations, like reading from a stream. Inelegant, and a bit of a sledgehammer, but
- * safe for the moment.
- */
- private class BlockableDispatcher(
- private val unblockedContext: CoroutineContext,
- private val blockingContext: ExecutorCoroutineDispatcher,
- scope: CoroutineScope
- ) : CoroutineScope by scope {
- companion object {
- fun <T> runIn(
- owningContext: CoroutineContext,
- exec: suspend BlockableDispatcher.() -> T
- ) =
- Executors.newCachedThreadPool().asCoroutineDispatcher().use { dispatcher ->
- runBlocking(owningContext) { BlockableDispatcher(owningContext, dispatcher, this).exec() }
- }
- }
-
- fun <T> blockable(action: () -> T): T {
- return runBlocking(blockingContext) {
- return@runBlocking action()
- }
- }
- }
@ExperimentalCoroutinesApi
override fun start(execute: Work) = WorkerContext.run {
+ //Use channel to serialize writing output
+ val writeChannel = Channel<WorkerProtocol.WorkResponse>(UNLIMITED)
captureIO().use { io ->
- BlockableDispatcher.runIn(coroutineContext) {
- blockable {
- generateSequence { WorkRequest.parseDelimitedFrom(io.input) }
- }.asFlow()
- .map { request ->
- info { "received req: ${request.requestId}" }
- async {
- doTask("request ${request.requestId}") { ctx ->
- request.argumentsList.run {
- execute(ctx, toList())
+ runBlocking {
+ //Parent coroutine to track all of children and close channel on completion
+ launch(Dispatchers.Default) {
+ generateSequence { WorkRequest.parseDelimitedFrom(io.input) }
+ .forEach { request ->
+ launch {
+ compileWork(request, io, writeChannel, execute)
}
- }.let { result ->
- info { "task result ${result.status}" }
- WorkerProtocol.WorkResponse.newBuilder().apply {
- output =
- listOf(
- result.log.out.toString(),
- io.captured.toByteArray().toString(UTF_8)
- ).filter { it.isNotBlank() }.joinToString("\n")
- exitCode = result.status.exit
- requestId = request.requestId
- }.build()
}
- }
- }
- .buffer()
- .map { deferred ->
- deferred.await()
- }
- .collect { response ->
- blockable {
- info {
- response.toString()
- }
- response.writeDelimitedTo(io.output)
- io.output.flush()
- }
- }
+ }.invokeOnCompletion { writeChannel.close() }
+
+ writeChannel.consumeAsFlow()
+ .collect { response -> writeOutput(response, io.output) }
}
+
io.output.close()
info { "stopped worker" }
}
return@run 0
}
+
+ private suspend fun WorkerContext.compileWork(
+ request: WorkRequest,
+ io: IO,
+ chan: Channel<WorkerProtocol.WorkResponse>,
+ execute: Work
+ ) = withContext(Dispatchers.Default) {
+ val result = doTask("request ${request.requestId}") { ctx ->
+ request.argumentsList.run {
+ execute(ctx, toList())
+ }
+ }
+ info { "task result ${result.status}" }
+ val response = WorkerProtocol.WorkResponse.newBuilder().apply {
+ output = listOf(
+ result.log.out.toString(),
+ io.captured.toByteArray().toString(UTF_8)
+ ).filter { it.isNotBlank() }.joinToString("\n")
+ exitCode = result.status.exit
+ requestId = request.requestId
+ }.build()
+ info {
+ response.toString()
+ }
+ chan.send(response)
+ }
+
+ private suspend fun writeOutput(response: WorkerProtocol.WorkResponse, output: PrintStream) =
+ withContext(Dispatchers.IO) {
+ response.writeDelimitedTo(output)
+ output.flush()
+ }
+
}