[WASM] POC Hotswap
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmBackendContext.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmBackendContext.kt index 84affb9..71dfe85 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmBackendContext.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmBackendContext.kt
@@ -109,6 +109,7 @@ } val fieldInitFunction = createInitFunction("fieldInit") + val hotSwapFieldInitFunction = createInitFunction("hotSwapFieldInit") val mainCallsWrapperFunction = createInitFunction("mainCallsWrapper") override val sharedVariablesManager =
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmSymbols.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmSymbols.kt index 062f08c..88fedbd 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmSymbols.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/WasmSymbols.kt
@@ -183,6 +183,7 @@ val refCastNull = getInternalFunction("wasm_ref_cast_deprecated") val wasmArrayCopy = getInternalFunction("wasm_array_copy") val wasmArrayNewData0 = getInternalFunction("array_new_data0") + val initiateHotReload = getInternalFunction("initiateHotReload") val intToLong = getInternalFunction("wasm_i64_extend_i32_s")
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/BodyGenerator.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/BodyGenerator.kt index 7f7e9f2..a8c0e15 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/BodyGenerator.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/BodyGenerator.kt
@@ -35,6 +35,7 @@ val functionContext: WasmFunctionCodegenContext, private val hierarchyDisjointUnions: DisjointUnions<IrClassSymbol>, private val isGetUnitFunction: Boolean, + private val isUserDefinedFunction: Boolean, ) : IrElementVisitorVoid { val body: WasmExpressionBuilder = functionContext.bodyGen @@ -193,8 +194,17 @@ generateInstanceFieldAccess(field, location) } } else { - body.buildGetGlobal(context.referenceGlobalField(field.symbol), location) - body.commentPreviousInstr { "type: ${field.type.render()}" } + if (isUserDefinedFunction) { + body.buildConstI32Symbol(context.referenceHotswapFieldGetterTableIndex(field.symbol), location) + body.buildCallIndirect( + context.referenceHotswapFieldGetterFunctionType(field.symbol), + WasmSymbol(1), + location + ) + } else { + body.buildGetGlobal(context.referenceGlobalField(field.symbol), location) + body.commentPreviousInstr { "type: ${field.type.render()}" } + } } } @@ -237,8 +247,17 @@ body.commentPreviousInstr { "name: ${field.name}, type: ${field.type.render()}" } } else { generateExpression(expression.value) - body.buildSetGlobal(context.referenceGlobalField(expression.symbol), location) - body.commentPreviousInstr { "type: ${field.type.render()}" } + if (isUserDefinedFunction) { + body.buildConstI32Symbol(context.referenceHotswapFieldSetterTableIndex(field.symbol), location) + body.buildCallIndirect( + context.referenceHotswapFieldSetterFunctionType(field.symbol), + WasmSymbol(2), + location + ) + } else { + body.buildSetGlobal(context.referenceGlobalField(expression.symbol), location) + body.commentPreviousInstr { "type: ${field.type.render()}" } + } } body.buildGetUnit() @@ -594,7 +613,7 @@ } wasmSymbols.unsafeGetScratchRawMemory -> { - + body.buildConstI32Symbol(context.scratchMemAddr, location) } @@ -618,6 +637,19 @@ body.buildInstr(WasmOp.ARRAY_NEW_DATA, location, arrayGcType, WasmImmediate.DataIdx(0)) } + wasmSymbols.initiateHotReload -> { + val qqq = SourceLocation.NoLocation("hot swap") + body.buildConstI32(8372, qqq) + body.buildConstI32(0, qqq) + body.buildConstI32(0, qqq) + body.buildInstr( + WasmOp.TABLE_COPY, + qqq, + WasmImmediate.TableIdx(0), + WasmImmediate.TableIdx(1) + ) + } + else -> { return false }
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/ClassInfo.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/ClassInfo.kt index cf89dc7..1b59c03 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/ClassInfo.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/ClassInfo.kt
@@ -11,10 +11,7 @@ import org.jetbrains.kotlin.ir.IrBuiltIns import org.jetbrains.kotlin.ir.backend.js.utils.eraseGenerics import org.jetbrains.kotlin.ir.backend.js.utils.realOverrideTarget -import org.jetbrains.kotlin.ir.declarations.IrClass -import org.jetbrains.kotlin.ir.declarations.IrDeclaration -import org.jetbrains.kotlin.ir.declarations.IrField -import org.jetbrains.kotlin.ir.declarations.IrSimpleFunction +import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.types.IrType import org.jetbrains.kotlin.ir.types.classifierOrFail import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable @@ -42,7 +39,7 @@ } } -fun IrSimpleFunction.wasmSignature(irBuiltIns: IrBuiltIns): WasmSignature = +fun IrFunction.wasmSignature(irBuiltIns: IrBuiltIns): WasmSignature = WasmSignature( name, extensionReceiverParameter?.type?.eraseGenerics(irBuiltIns),
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/DeclarationGenerator.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/DeclarationGenerator.kt index 5083bf9..8085149 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/DeclarationGenerator.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/DeclarationGenerator.kt
@@ -36,6 +36,9 @@ private val hierarchyDisjointUnions: DisjointUnions<IrClassSymbol>, ) : IrElementVisitorVoid { + //TODO + var isHotSwapEnabledModule = false + // Shortcuts private val backendContext: WasmBackendContext = context.backendContext private val irBuiltIns: IrBuiltIns = backendContext.irBuiltIns @@ -57,7 +60,7 @@ private fun jsCodeName(declaration: IrFunction): String { require(declaration is IrSimpleFunction) val fqName = declaration.fqNameWhenAvailable!!.asString() - val hashCode = declaration.wasmSignature(irBuiltIns).hashCode() + declaration.file.path.hashCode() + val hashCode = declaration.wasmSignature(irBuiltIns).toString().hashCode() + declaration.file.path.hashCode() return "${fqName}_$hashCode" } @@ -84,7 +87,7 @@ // check(declaration.isExternal) { "Non-external fun with @JsFun ${declaration.fqNameWhenAvailable}"} val jsCodeName = jsCodeName(declaration) context.addJsFun(jsCodeName, jsCode) - WasmImportDescriptor("js_code", jsCodeName(declaration)) + WasmImportDescriptor("js_code", jsCodeName) } else -> { null @@ -148,7 +151,8 @@ context = context, functionContext = functionCodegenContext, hierarchyDisjointUnions = hierarchyDisjointUnions, - isGetUnitFunction = declaration == unitGetInstanceFunction + isGetUnitFunction = declaration == unitGetInstanceFunction, + isUserDefinedFunction = isHotSwapEnabledModule, ) if (declaration is IrConstructor) { @@ -172,11 +176,29 @@ exprGen.buildUnreachableForVerifier() } - context.defineFunction(declaration.symbol, function) + if (declaration == backendContext.hotSwapFieldInitFunction) { + context.addExport(WasmExport.Function("__hotSwapInit", function)) + } + + val functionToDefine: WasmFunction + if (isHotSwapEnabledModule && declaration.kotlinFqName.parentOrNull() != context.backendContext.kotlinWasmInternalPackageFqn) { + functionToDefine = wrapFunctionToHotswapBridge( + declaration, + function, + functionTypeSymbol, + irParameters, + watName + ) + } else { + functionToDefine = function + } + + context.defineFunction(declaration.symbol, functionToDefine) val initPriority = when (declaration) { - backendContext.fieldInitFunction -> "0" - backendContext.mainCallsWrapperFunction -> "1" + backendContext.hotSwapFieldInitFunction -> "0" + backendContext.fieldInitFunction -> "1" + backendContext.mainCallsWrapperFunction -> "2" else -> null } if (initPriority != null) @@ -192,6 +214,50 @@ } } + private fun wrapFunctionToHotswapBridge( + declaration: IrFunction, + function: WasmFunction, + functionTypeSymbol: WasmSymbol<WasmFunctionType>, + irParameters: List<IrValueParameter>, + watName: String + ): WasmFunction { + context.defineHotswapFunction(declaration.symbol, function) + val hotswapBridge = WasmFunction.Defined(watName + "_bridge", functionTypeSymbol) + + val hotswapBridgeGenerationContext = WasmFunctionCodegenContext( + declaration, + hotswapBridge, + backendContext, + context + ) + + for (irParameter in irParameters) { + hotswapBridgeGenerationContext.defineLocal(irParameter.symbol) + } + + val exprGen = hotswapBridgeGenerationContext.bodyGen + + val bridgeLocation = SourceLocation.NoLocation("hotswap bridge") + + for (irParameter in irParameters) { + exprGen.buildGetLocal( + hotswapBridgeGenerationContext.referenceLocal(irParameter.symbol), + bridgeLocation + ) + } + + exprGen.buildConstI32Symbol(context.referenceHotswapTableIndex(declaration.symbol), bridgeLocation) + + exprGen.buildCallIndirect( + functionTypeSymbol, + WasmSymbol(0), + bridgeLocation + ) + + exprGen.buildInstr(WasmOp.RETURN, bridgeLocation) + return hotswapBridge + } + private fun createDeclarationByInterface(iFace: IrClassSymbol) { if (context.isAlreadyDefinedClassITableGcType(iFace)) return if (iFace !in hierarchyDisjointUnions) return @@ -467,14 +533,47 @@ generateDefaultInitializerForType(wasmType, wasmExpressionGenerator) } + val fieldName = declaration.fqNameWhenAvailable.toString() + val global = WasmGlobal( - name = declaration.fqNameWhenAvailable.toString(), + name = fieldName, type = wasmType, isMutable = true, init = initBody ) context.defineGlobalField(declaration.symbol, global) + + /// GETTER + val bridgeLocation = SourceLocation.NoLocation("field hotswap getter bridge") + val getterName = "${fieldName}_getter_bridge" + + val wasmGetterFunctionType = WasmFunctionType(parameterTypes = emptyList(), resultTypes = listOf(wasmType)) + val getterBridgeTypeReference = context.referenceHotswapFieldGetterFunctionType(declaration.symbol) + context.defineHotswapFieldGetterFunctionType(declaration.symbol, wasmGetterFunctionType) + + val getterBridgeFunction = WasmFunction.Defined(getterName, getterBridgeTypeReference) + with(WasmIrExpressionBuilder(getterBridgeFunction.instructions)) { + buildGetGlobal(context.referenceGlobalField(declaration.symbol), bridgeLocation) + buildInstr(WasmOp.RETURN, bridgeLocation) + } + context.defineHotswapFieldGetter(declaration.symbol, getterBridgeFunction) + + /// SETTER + val setterName = "${fieldName}_setter_bridge" + + val wasmSetterFunctionType = WasmFunctionType(parameterTypes = listOf(wasmType), resultTypes = emptyList()) + val setterBridgeTypeReference = context.referenceHotswapFieldSetterFunctionType(declaration.symbol) + context.defineHotswapFieldSetterFunctionType(declaration.symbol, wasmSetterFunctionType) + + val setterLocal = WasmLocal(0, "param0", wasmType, true) + val setterBridgeFunction = WasmFunction.Defined(setterName, setterBridgeTypeReference, mutableListOf(setterLocal)) + with(WasmIrExpressionBuilder(setterBridgeFunction.instructions)) { + buildGetLocal(setterLocal, bridgeLocation) + buildSetGlobal(context.referenceGlobalField(declaration.symbol), bridgeLocation) + buildInstr(WasmOp.RETURN, bridgeLocation) + } + context.defineHotswapFieldSetter(declaration.symbol, setterBridgeFunction) } }
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmCompiledModuleFragment.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmCompiledModuleFragment.kt index ee8f869..fd8c9fb 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmCompiledModuleFragment.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmCompiledModuleFragment.kt
@@ -6,25 +6,41 @@ package org.jetbrains.kotlin.backend.wasm.ir2wasm import org.jetbrains.kotlin.ir.IrBuiltIns -import org.jetbrains.kotlin.ir.declarations.IrDeclarationWithName -import org.jetbrains.kotlin.ir.declarations.IrExternalPackageFragment +import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.symbols.* import org.jetbrains.kotlin.ir.util.fqNameWhenAvailable import org.jetbrains.kotlin.ir.util.getPackageFragment +import org.jetbrains.kotlin.ir.util.kotlinFqName import org.jetbrains.kotlin.wasm.ir.* import org.jetbrains.kotlin.wasm.ir.source.location.SourceLocation class WasmCompiledModuleFragment(val irBuiltIns: IrBuiltIns) { val functions = ReferencableAndDefinable<IrFunctionSymbol, WasmFunction>() + val hotswapFunctions = + ReferencableAndDefinable<IrFunctionSymbol, WasmFunction>() + val hotswapFunctionIndexes = + ReferencableElements<IrFunctionSymbol, Int>() val globalFields = ReferencableAndDefinable<IrFieldSymbol, WasmGlobal>() + val hotswapFieldGetter = + ReferencableAndDefinable<IrFieldSymbol, WasmFunction>() + val hotswapFieldSetter = + ReferencableAndDefinable<IrFieldSymbol, WasmFunction>() + val hotswapFieldGetterIndexes = + ReferencableElements<IrFieldSymbol, Int>() + val hotswapFieldSetterIndexes = + ReferencableElements<IrFieldSymbol, Int>() val globalVTables = ReferencableAndDefinable<IrClassSymbol, WasmGlobal>() val globalClassITables = ReferencableAndDefinable<IrClassSymbol, WasmGlobal>() val functionTypes = ReferencableAndDefinable<IrFunctionSymbol, WasmFunctionType>() + val hotSwapFieldGetterBridgesFunctionTypes = + ReferencableAndDefinable<IrFieldSymbol, WasmFunctionType>() + val hotSwapFieldSetterBridgesFunctionTypes = + ReferencableAndDefinable<IrFieldSymbol, WasmFunctionType>() val gcTypes = ReferencableAndDefinable<IrClassSymbol, WasmTypeDeclaration>() val vTableGcTypes = @@ -101,6 +117,9 @@ fun linkWasmCompiledFragments(): WasmModule { bind(functions.unbound, functions.defined) + bind(hotswapFunctions.unbound, hotswapFunctions.defined) + bind(hotswapFieldGetter.unbound, hotswapFieldGetter.defined) + bind(hotswapFieldSetter.unbound, hotswapFieldSetter.defined) bind(globalFields.unbound, globalFields.defined) bind(globalVTables.unbound, globalVTables.defined) bind(gcTypes.unbound, gcTypes.defined) @@ -109,9 +128,12 @@ bind(classITableInterfaceSlot.unbound, classITableInterfaceSlot.defined) bind(globalClassITables.unbound, globalClassITables.defined) + val allFunctionTypeElements = + functionTypes.elements + hotSwapFieldGetterBridgesFunctionTypes.elements + hotSwapFieldSetterBridgesFunctionTypes.elements + // Associate function types to a single canonical function type val canonicalFunctionTypes = - functionTypes.elements.associateWithTo(LinkedHashMap()) { it } + allFunctionTypeElements.associateWithTo(LinkedHashMap()) { it } functionTypes.unbound.forEach { (irSymbol, wasmSymbol) -> if (irSymbol !in functionTypes.defined) @@ -119,6 +141,141 @@ wasmSymbol.bind(canonicalFunctionTypes.getValue(functionTypes.defined.getValue(irSymbol))) } + hotSwapFieldGetterBridgesFunctionTypes.unbound.forEach { (irSymbol, wasmSymbol) -> + if (irSymbol !in hotSwapFieldGetterBridgesFunctionTypes.defined) + error("Can't link symbol ${irSymbolDebugDump(irSymbol)}") + wasmSymbol.bind(canonicalFunctionTypes.getValue(hotSwapFieldGetterBridgesFunctionTypes.defined.getValue(irSymbol))) + } + + hotSwapFieldSetterBridgesFunctionTypes.unbound.forEach { (irSymbol, wasmSymbol) -> + if (irSymbol !in hotSwapFieldSetterBridgesFunctionTypes.defined) + error("Can't link symbol ${irSymbolDebugDump(irSymbol)}") + wasmSymbol.bind(canonicalFunctionTypes.getValue(hotSwapFieldSetterBridgesFunctionTypes.defined.getValue(irSymbol))) + } + + val oldSwapTableMap: MutableMap<String, Int> = mutableMapOf() + val aaa = listOf( + "<get-a> [(non-virtual) <get-a>() -> kotlin.Int][DELEGATED_PROPERTY_ACCESSOR]", + "<set-count> [(non-virtual) <set-count>(kotlin.Int) -> kotlin.Unit][DEFAULT_PROPERTY_ACCESSOR]", + "<get-count> [(non-virtual) <get-count>() -> kotlin.Int][DEFAULT_PROPERTY_ACCESSOR]", + "<get-q> [(non-virtual) <get-q>() -> kotlin.Int][DEFAULT_PROPERTY_ACCESSOR]", + "externLol__externalAdapter [(non-virtual) externLol__externalAdapter() -> kotlin.String][DEFINED]", + "box [(non-virtual) box() -> kotlin.String][DEFINED]", + "box__JsExportAdapter [(non-virtual) box__JsExportAdapter() -> kotlin.wasm.internal.ExternalInterfaceType?][DEFINED]", + "appendElement [(non-virtual) (er: org.w3c.dom.HTMLElement) appendElement(kotlin.String) -> org.w3c.dom.Element][DEFINED]", + "update [(non-virtual) update() -> kotlin.Unit][DEFINED]", + "<get-a>\$ref.<init> [(non-virtual) <init>() -> <root>.<get-a>\$ref][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "<get-a>\$ref.invoke [invoke() -> kotlin.Int][DEFINED]", + "<get-a>\$ref.invoke [invoke() -> kotlin.Any?][BRIDGE]", + "<get-a>\$ref.<get-name> [<get-name>() -> kotlin.String][DEFINED]", + "a\$delegate\$lambda.<init> [(non-virtual) <init>() -> <root>.a\$delegate\$lambda][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "a\$delegate\$lambda.invoke [invoke() -> kotlin.Int][DEFINED]", + "a\$delegate\$lambda.invoke [invoke() -> kotlin.Any?][BRIDGE]", + "box\$lambda\$lambda\$lambda.<init> [(non-virtual) <init>(org.w3c.dom.Element) -> <root>.box\$lambda\$lambda\$lambda][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "box\$lambda\$lambda\$lambda.invoke [invoke(org.w3c.dom.events.MouseEvent) -> kotlin.Nothing?][DEFINED]", + "box\$lambda\$lambda\$lambda.invoke [invoke(kotlin.Any?) -> kotlin.Any?][BRIDGE]", + "box\$lambda\$lambda.<init> [(non-virtual) <init>() -> <root>.box\$lambda\$lambda][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "box\$lambda\$lambda.invoke [invoke(org.w3c.dom.Element) -> kotlin.Unit][DEFINED]", + "box\$lambda\$lambda.invoke [invoke(kotlin.Any?) -> kotlin.Any?][BRIDGE]", + "box\$lambda\$lambda.<init> [(non-virtual) <init>(org.w3c.dom.Element) -> <root>.box\$lambda\$lambda][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "box\$lambda\$lambda.invoke [invoke(org.w3c.dom.Element) -> kotlin.Unit][DEFINED]~1", + "box\$lambda\$lambda.invoke [invoke(kotlin.Any?) -> kotlin.Any?][BRIDGE]~1", + "box\$lambda.<init> [(non-virtual) <init>() -> <root>.box\$lambda][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "box\$lambda.invoke [invoke(org.w3c.dom.Element) -> kotlin.Unit][DEFINED]", + "box\$lambda.invoke [invoke(kotlin.Any?) -> kotlin.Any?][BRIDGE]", + "appendElement\$lambda.<init> [(non-virtual) <init>() -> <root>.appendElement\$lambda][GENERATED_MEMBER_IN_CALLABLE_REFERENCE]", + "appendElement\$lambda.invoke [invoke(org.w3c.dom.Element) -> kotlin.Unit][DEFINED]", + "appendElement\$lambda.invoke [invoke(kotlin.Any?) -> kotlin.Any?][BRIDGE]", + "<init properties surrogatePair.kt> [(non-virtual) <init properties surrogatePair.kt>() -> kotlin.Unit][SYNTHESIZED_DECLARATION]", + "kotlin.wasm.internal.\$closureBox\$.<init> [(non-virtual) <init>(kotlin.Any?) -> kotlin.wasm.internal.\$closureBox\$][JS_CLOSURE_BOX_CLASS_DECLARATION]" + ) +// val aaa = emptyList<String>() + aaa.forEachIndexed { index, s -> oldSwapTableMap[s] = index } + + val importDescriptor = WasmImportDescriptor("hotswap_import", "hotswap_replacement_table") + .takeIf { oldSwapTableMap.isNotEmpty() } + + val hotSwapTable = WasmTable( + WasmLimits(0U, null), + WasmFuncRef, + importDescriptor + ) + exports += WasmExport.Table("hotswap_table", hotSwapTable) + + val hotSwapGettersTable = WasmTable( + WasmLimits(hotswapFieldGetter.defined.entries.size.toUInt(), hotswapFieldGetter.defined.entries.size.toUInt()), + WasmFuncRef + ) + + val hotSwapSettersTable = WasmTable( + WasmLimits(hotswapFieldSetter.defined.entries.size.toUInt(), hotswapFieldSetter.defined.entries.size.toUInt()), + WasmFuncRef + ) + + val newSwapTableMap = mutableMapOf<String, Int>() + val newSwapTableReplaceMap = mutableMapOf<Int, WasmFunction>() + + val functionIds = mutableSetOf<String>() + fun IrFunctionSymbol.toId(): String { + val signature = owner.wasmSignature(irBuiltIns).toString() + val fqName = owner.kotlinFqName.asString() + val origin = owner.origin.toString() + val functionId = "$fqName $signature[$origin]" + var functionIdWithIndex = functionId + var functionIdIndex = 0 + while (!functionIds.add(functionIdWithIndex)) { + functionIdIndex++ + functionIdWithIndex = "$functionId~$functionIdIndex" + } + return functionIdWithIndex + } + + var newFunctionsCount = oldSwapTableMap.size + for (entry in hotswapFunctions.defined.entries) { + val functionId = entry.key.toId() + val index = oldSwapTableMap[functionId] ?: newFunctionsCount++ + hotswapFunctionIndexes.reference(entry.key).bind(index) + newSwapTableMap[functionId] = index + newSwapTableReplaceMap[index] = entry.value + } + + val rewriteFunctionTableType = WasmFunctionType(emptyList(), emptyList()) + val rewriteFunctionTable = WasmFunction.Defined("__rewriteFunctionTable", WasmSymbol(rewriteFunctionTableType)) + val hotSwapLocation = SourceLocation.NoLocation("Generated service code") + with(WasmIrExpressionBuilder(rewriteFunctionTable.instructions)) { + buildRefNull(WasmFuncRef.getHeapType(), hotSwapLocation) + buildConstI32(newFunctionsCount - oldSwapTableMap.size, hotSwapLocation) + buildInstr(WasmOp.TABLE_GROW, hotSwapLocation, WasmImmediate.TableIdx(0)) + buildDrop(hotSwapLocation) + for (entry in newSwapTableReplaceMap) { + buildConstI32(entry.key, hotSwapLocation) + buildInstr(WasmOp.REF_FUNC, hotSwapLocation, WasmImmediate.FuncIdx(entry.value)) + buildInstr(WasmOp.TABLE_SET, hotSwapLocation, WasmImmediate.TableIdx(0)) + } + } + exports += WasmExport.Function("__rewriteFunctionTable", rewriteFunctionTable) + + hotswapFieldGetter.defined.entries.forEachIndexed { index, entry -> hotswapFieldGetterIndexes.reference(entry.key).bind(index) } + hotswapFieldSetter.defined.entries.forEachIndexed { index, entry -> hotswapFieldSetterIndexes.reference(entry.key).bind(index) } + + val hotSwapElement = WasmElement( + WasmFuncRef, + hotswapFunctions.defined.entries.map { WasmTable.Value.Function(it.value) }, + WasmElement.Mode.Declarative + ) + + val hotSwapGettersElement = WasmElement( + WasmFuncRef, + hotswapFieldGetter.defined.entries.map { WasmTable.Value.Function(it.value) }, + WasmElement.Mode.Active(hotSwapGettersTable, listOf(WasmInstrWithoutLocation(WasmOp.I32_CONST, listOf(WasmImmediate.ConstI32(0))))) + ) + + val hotSwapSettersElement = WasmElement( + WasmFuncRef, + hotswapFieldSetter.defined.entries.map { WasmTable.Value.Function(it.value) }, + WasmElement.Mode.Active(hotSwapSettersTable, listOf(WasmInstrWithoutLocation(WasmOp.I32_CONST, listOf(WasmImmediate.ConstI32(0))))) + ) + val klassIds = mutableMapOf<IrClassSymbol, Int>() var currentDataSectionAddress = 0 for (typeInfoElement in typeInfo.elements) { @@ -165,6 +322,7 @@ val masterInitFunctionType = WasmFunctionType(emptyList(), emptyList()) val masterInitFunction = WasmFunction.Defined("__init", WasmSymbol(masterInitFunctionType)) with(WasmIrExpressionBuilder(masterInitFunction.instructions)) { + buildCall(WasmSymbol(rewriteFunctionTable), SourceLocation.NoLocation("Generated service code")) initFunctions.sortedBy { it.priority }.forEach { buildCall(WasmSymbol(it.function), SourceLocation.NoLocation("Generated service code")) } @@ -179,6 +337,9 @@ // Export name "memory" is a WASI ABI convention. exports += WasmExport.Memory("memory", memory) + exports += WasmExport.Table("hotswap_getters_table", hotSwapGettersTable) + exports += WasmExport.Table("hotswap_setters_table", hotSwapSettersTable) + val importedFunctions = functions.elements.filterIsInstance<WasmFunction.Imported>() fun wasmTypeDeclarationOrderKey(declaration: WasmTypeDeclaration): Int { @@ -202,7 +363,7 @@ globals.addAll(globalVTables.elements) globals.addAll(globalClassITables.elements.distinct()) - val allFunctionTypes = canonicalFunctionTypes.values.toList() + tagFuncType + masterInitFunctionType + val allFunctionTypes = canonicalFunctionTypes.values.toList() + tagFuncType + masterInitFunctionType + rewriteFunctionTableType // Partition out function types that can't be recursive, // we don't need to put them into a rec group @@ -211,18 +372,25 @@ allFunctionTypes.partition { it.referencesTypeDeclarations() } recGroupTypes.addAll(potentiallyRecursiveFunctionTypes) + + val hotSwapReplacementTableToImport = hotSwapTable.takeIf { oldSwapTableMap.isNotEmpty() } + if (hotSwapReplacementTableToImport != null) { + jsModuleImports.add("hotswap_import") + } + val module = WasmModule( functionTypes = nonRecursiveFunctionTypes, recGroupTypes = recGroupTypes, - importsInOrder = importedFunctions, + importsInOrder = importedFunctions + listOfNotNull(hotSwapReplacementTableToImport), importedFunctions = importedFunctions, - definedFunctions = functions.elements.filterIsInstance<WasmFunction.Defined>() + masterInitFunction, - tables = emptyList(), + importedTables = listOfNotNull(hotSwapReplacementTableToImport), + definedFunctions = (functions.elements + hotswapFunctions.elements + hotswapFieldGetter.elements + hotswapFieldSetter.elements).filterIsInstance<WasmFunction.Defined>() + masterInitFunction + rewriteFunctionTable, + tables = listOfNotNull(hotSwapTable.takeIf { oldSwapTableMap.isEmpty() }, hotSwapGettersTable, hotSwapSettersTable), memories = listOf(memory), globals = globals, exports = exports, startFunction = null, // Module is initialized via export call - elements = emptyList(), + elements = listOf(hotSwapElement, hotSwapGettersElement, hotSwapSettersElement), data = data, dataCount = true, tags = listOf(tag)
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleCodegenContext.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleCodegenContext.kt index 8bc96c8..8aa9d91 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleCodegenContext.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleCodegenContext.kt
@@ -83,6 +83,10 @@ wasmFragment.functions.define(irFunction, wasmFunction) } + fun defineHotswapFunction(irFunction: IrFunctionSymbol, wasmFunction: WasmFunction) { + wasmFragment.hotswapFunctions.define(irFunction, wasmFunction) + } + fun defineGlobalField(irField: IrFieldSymbol, wasmGlobal: WasmGlobal) { wasmFragment.globalFields.define(irField, wasmGlobal) } @@ -107,6 +111,20 @@ wasmFragment.functionTypes.define(irFunction, wasmFunctionType) } + fun defineHotswapFieldGetterFunctionType(irField: IrFieldSymbol, wasmFunctionType: WasmFunctionType) { + wasmFragment.hotSwapFieldGetterBridgesFunctionTypes.define(irField, wasmFunctionType) + } + + fun defineHotswapFieldSetterFunctionType(irField: IrFieldSymbol, wasmFunctionType: WasmFunctionType) { + wasmFragment.hotSwapFieldSetterBridgesFunctionTypes.define(irField, wasmFunctionType) + } + + fun defineHotswapFieldGetter(irField: IrFieldSymbol, wasmFunction: WasmFunction) = + wasmFragment.hotswapFieldGetter.define(irField, wasmFunction) + + fun defineHotswapFieldSetter(irField: IrFieldSymbol, wasmFunction: WasmFunction) = + wasmFragment.hotswapFieldSetter.define(irField, wasmFunction) + private val classMetadataCache = mutableMapOf<IrClassSymbol, ClassMetadata>() fun getClassMetadata(irClass: IrClassSymbol): ClassMetadata = classMetadataCache.getOrPut(irClass) { @@ -126,6 +144,15 @@ fun referenceFunction(irFunction: IrFunctionSymbol): WasmSymbol<WasmFunction> = wasmFragment.functions.reference(irFunction) + fun referenceHotswapTableIndex(irFunction: IrFunctionSymbol): WasmSymbol<Int> = + wasmFragment.hotswapFunctionIndexes.reference(irFunction) + + fun referenceHotswapFieldGetterTableIndex(irField: IrFieldSymbol): WasmSymbol<Int> = + wasmFragment.hotswapFieldGetterIndexes.reference(irField) + + fun referenceHotswapFieldSetterTableIndex(irField: IrFieldSymbol): WasmSymbol<Int> = + wasmFragment.hotswapFieldSetterIndexes.reference(irField) + fun referenceGlobalField(irField: IrFieldSymbol): WasmSymbol<WasmGlobal> = wasmFragment.globalFields.reference(irField) @@ -177,6 +204,12 @@ fun referenceFunctionType(irFunction: IrFunctionSymbol): WasmSymbol<WasmFunctionType> = wasmFragment.functionTypes.reference(irFunction) + fun referenceHotswapFieldGetterFunctionType(irField: IrFieldSymbol): WasmSymbol<WasmFunctionType> = + wasmFragment.hotSwapFieldGetterBridgesFunctionTypes.reference(irField) + + fun referenceHotswapFieldSetterFunctionType(irField: IrFieldSymbol): WasmSymbol<WasmFunctionType> = + wasmFragment.hotSwapFieldSetterBridgesFunctionTypes.reference(irField) + fun referenceClassId(irClass: IrClassSymbol): WasmSymbol<Int> = wasmFragment.classIds.reference(irClass)
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleFragmentGenerator.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleFragmentGenerator.kt index 74be074..95edfe1 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleFragmentGenerator.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/ir2wasm/WasmModuleFragmentGenerator.kt
@@ -58,6 +58,7 @@ } fun generateModule(irModuleFragment: IrModuleFragment) { + declarationGenerator.isHotSwapEnabledModule = irModuleFragment.name.asString() == "<main>" acceptVisitor(irModuleFragment, declarationGenerator) }
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/FieldInitializersLowering.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/FieldInitializersLowering.kt index c163b9e..c7a4c3e 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/FieldInitializersLowering.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/FieldInitializersLowering.kt
@@ -34,6 +34,7 @@ override fun lower(irFile: IrFile) { val builder = context.createIrBuilder(context.fieldInitFunction.symbol) val startFunctionBody = context.fieldInitFunction.body as IrBlockBody + val hotSwapInitFunctionBody = context.hotSwapFieldInitFunction.body as IrBlockBody irFile.acceptChildrenVoid(object : IrElementVisitorVoid { override fun visitElement(element: IrElement) { @@ -57,7 +58,7 @@ val initializerStatement = builder.at(initValue).irSetField(null, declaration, initValue) when (declaration.fqNameWhenAvailable) { - stringPoolFqName -> startFunctionBody.statements.add(0, initializerStatement) + stringPoolFqName -> hotSwapInitFunctionBody.statements.add(0, initializerStatement) else -> startFunctionBody.statements.add(initializerStatement) }
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/JsInteropFunctionsLowering.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/JsInteropFunctionsLowering.kt index 6e13cd6..9dd4c5c 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/JsInteropFunctionsLowering.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/JsInteropFunctionsLowering.kt
@@ -621,7 +621,7 @@ } val hashString: String = - functionType.hashCode().absoluteValue.toString(Character.MAX_RADIX) + functionType.render().hashCode().absoluteValue.toString(Character.MAX_RADIX) val originalParameterTypes: List<IrType> = functionType.arguments.dropLast(1).map { (it as IrTypeProjection).type }
diff --git a/compiler/testData/codegen/box/strings/surrogatePair.kt b/compiler/testData/codegen/box/strings/surrogatePair.kt index e220ce5..bb5d8fa 100644 --- a/compiler/testData/codegen/box/strings/surrogatePair.kt +++ b/compiler/testData/codegen/box/strings/surrogatePair.kt
@@ -1,8 +1,45 @@ -// Will be executed on JDK 9, 11, 17 -fun test(s: String): String { - return "\ud83c" + s + "\udf09"; -} +import kotlinx.browser.* +import org.w3c.dom.HTMLElement +import org.w3c.dom.HTMLButtonElement +import kotlinx.dom.* +import org.w3c.dom.events.MouseEvent + +val a: Int by lazy { 2 } + +var count = 1 + +const val q = 2 + +@JsFun("() => 'kek'") +external fun externLol(): String fun box() : String { - return if (test("") == "\ud83c\udf09") "OK" else "fail: ${test("")}" -} \ No newline at end of file + val test = document.body!! as HTMLElement + val newElement = test.appendElement("div") { + val inner = appendElement("div") { + textContent = count.toString() + } + appendElement("button") { + this as HTMLButtonElement + textContent = "GO" + onclick = { + update() + inner.textContent = count.toString() + null + } + } + } + + return "OK" +} + +fun HTMLElement.appendElement(text: String) = appendElement(text) {} + +fun update() { + count += a + val doc = document.body!! as HTMLElement + doc.appendElement("h1") { + textContent = count.toString() + } +} +
diff --git a/libraries/stdlib/wasm/internal/kotlin/wasm/internal/Runtime.kt b/libraries/stdlib/wasm/internal/kotlin/wasm/internal/Runtime.kt index 66be68b..52fe0b0 100644 --- a/libraries/stdlib/wasm/internal/kotlin/wasm/internal/Runtime.kt +++ b/libraries/stdlib/wasm/internal/kotlin/wasm/internal/Runtime.kt
@@ -116,4 +116,13 @@ // This initializer is a special case in FieldInitializersLowering @EagerInitialization -internal val stringPool: Array<String?> = arrayOfNulls(stringGetPoolSize()) \ No newline at end of file +internal val stringPool: Array<String?> = arrayOfNulls(stringGetPoolSize()) + +@ExcludedFromCodegen +internal fun initiateHotReload(): Unit = + implementedAsIntrinsic + +@JsExport +fun makeHotSwap() { + initiateHotReload() +} \ No newline at end of file
diff --git a/wasm/wasm.ir/src/org/jetbrains/kotlin/wasm/ir/convertors/WasmIrToBinary.kt b/wasm/wasm.ir/src/org/jetbrains/kotlin/wasm/ir/convertors/WasmIrToBinary.kt index 3cd8df4..0f16fba 100644 --- a/wasm/wasm.ir/src/org/jetbrains/kotlin/wasm/ir/convertors/WasmIrToBinary.kt +++ b/wasm/wasm.ir/src/org/jetbrains/kotlin/wasm/ir/convertors/WasmIrToBinary.kt
@@ -154,7 +154,7 @@ appendVectorSize(definedFunctions.size) definedFunctions.forEach { appendModuleFieldReference(it) - b.writeString(it.name) + b.writeString("${it.name}_${it.id}") } } appendSection(2u) { @@ -164,7 +164,7 @@ appendVectorSize(it.locals.size) it.locals.forEach { local -> b.writeVarUInt32(local.id) - b.writeString(local.name) + b.writeString("${local.name}_${local.id}") } } }