in progress
diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/phaser/PhaseFactories.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/phaser/PhaseFactories.kt index c3a6bc3..e01918f 100644 --- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/phaser/PhaseFactories.kt +++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/phaser/PhaseFactories.kt
@@ -19,7 +19,7 @@ annotation class PhaseDescription( val name: String, - val prerequisite: Array<KClass<out FileLoweringPass>> = [], + val prerequisite: Array<KClass<out ModuleLoweringPass>> = [], ) fun <Context : LoweringContext> createFilePhases(
diff --git a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/JvmIrInliner.kt b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/JvmIrInliner.kt index c3c3fa7..38dea2a 100644 --- a/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/JvmIrInliner.kt +++ b/compiler/ir/backend.jvm/lower/src/org/jetbrains/kotlin/backend/jvm/lower/JvmIrInliner.kt
@@ -10,6 +10,8 @@ import org.jetbrains.kotlin.backend.jvm.ir.isInlineFunctionCall import org.jetbrains.kotlin.ir.declarations.IrFile import org.jetbrains.kotlin.ir.declarations.IrFunction +import org.jetbrains.kotlin.ir.declarations.IrModuleFragment +import org.jetbrains.kotlin.ir.inline.AbstractInlineFunctionResolver import org.jetbrains.kotlin.ir.inline.FunctionInlining import org.jetbrains.kotlin.ir.inline.InlineFunctionResolver import org.jetbrains.kotlin.ir.inline.InlineMode @@ -27,13 +29,13 @@ ) { private val enabled = context.config.enableIrInliner - override fun lower(irFile: IrFile) { + override fun lower(irModule: IrModuleFragment) { if (enabled) { - super.lower(irFile) + super.lower(irModule) } } } -class JvmInlineFunctionResolver(private val context: JvmBackendContext) : InlineFunctionResolver(InlineMode.ALL_INLINE_FUNCTIONS) { +class JvmInlineFunctionResolver(private val context: JvmBackendContext) : AbstractInlineFunctionResolver(InlineMode.ALL_INLINE_FUNCTIONS) { override fun needsInlining(symbol: IrFunctionSymbol): Boolean = symbol.owner.isInlineFunctionCall(context) }
diff --git a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/WasmFunctionInlining.kt b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/WasmFunctionInlining.kt index 1eb2ae7..ca0bbbb 100644 --- a/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/WasmFunctionInlining.kt +++ b/compiler/ir/backend.wasm/src/org/jetbrains/kotlin/backend/wasm/lower/WasmFunctionInlining.kt
@@ -21,7 +21,7 @@ context = context, inlineFunctionResolver = WasmInlineFunctionResolver(context, inlineMode), produceOuterThisFields = false, - ).inline(irModule) + ).lower(irModule) irModule.patchDeclarationParents() }
diff --git a/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/FunctionInlining.kt b/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/FunctionInlining.kt index af37b9d..3d18db4 100644 --- a/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/FunctionInlining.kt +++ b/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/FunctionInlining.kt
@@ -21,7 +21,6 @@ import org.jetbrains.kotlin.ir.expressions.impl.* import org.jetbrains.kotlin.ir.originalBeforeInline import org.jetbrains.kotlin.ir.symbols.* -import org.jetbrains.kotlin.ir.symbols.impl.IrReturnableBlockSymbolImpl import org.jetbrains.kotlin.ir.types.* import org.jetbrains.kotlin.ir.util.* import org.jetbrains.kotlin.ir.visitors.* @@ -35,22 +34,108 @@ private val insertAdditionalImplicitCasts: Boolean = true, private val regenerateInlinedAnonymousObjects: Boolean = false, private val produceOuterThisFields: Boolean = true, -) : IrElementTransformerVoidWithContext(), BodyLoweringPass { +) : ModuleLoweringPass, FileLoweringPass { + class WrapperResolver( + val cached: Map<IrFunctionSymbol, IrFunction>, + val delegate: InlineFunctionResolver, + ) : InlineFunctionResolver by delegate { + override fun getFunctionDeclaration(symbol: IrFunctionSymbol): IrFunction? { + return cached[symbol] ?: delegate.getFunctionDeclaration(symbol) + } + } + + private fun lower(container: IrElement) { + val dependencies = mutableMapOf<IrFunctionSymbol, MutableList<IrFunctionSymbol>>() + container.acceptChildrenVoid(object : IrVisitorVoid() { + val insideInlineFunction = mutableListOf<IrFunctionSymbol>() + override fun visitFunction(declaration: IrFunction) { + if (inlineFunctionResolver.needsInlining(declaration.symbol)) { + insideInlineFunction.add(declaration.symbol) + super.visitFunction(declaration) + insideInlineFunction.removeLast() + } else { + super.visitFunction(declaration) + } + } + + override fun visitFunctionAccess(expression: IrFunctionAccessExpression) { + super.visitFunctionAccess(expression) + val callee = expression.symbol as? IrSimpleFunctionSymbol ?: return + if (inlineFunctionResolver.needsInlining(callee)) { + for (callSite in insideInlineFunction) { + dependencies.getOrPut(callSite) { mutableListOf() }.add(callee) + } + } + } + }) + val cache = mutableMapOf<IrFunctionSymbol, IrFunction>() + val inProgress = mutableSetOf<IrFunctionSymbol>() + val resolver = WrapperResolver(cache, inlineFunctionResolver) + val inlineTransformer = FunctionInliningTransformer( + context, + resolver, + insertAdditionalImplicitCasts, + regenerateInlinedAnonymousObjects, + produceOuterThisFields, + ) + + val inliningOrder = mutableListOf<IrFunctionSymbol>() + + fun computeOrder(callee: IrFunctionSymbol) { + if (cache.containsKey(callee) || callee !in dependencies.keys) return + if (!inProgress.add(callee)) { + TODO("Report recursive inlining") + } + for (dep in dependencies[callee] ?: emptyList()) { + computeOrder(dep) + } + inliningOrder.add(callee) + inProgress.remove(callee) + } + + for (callee in dependencies.keys) { + computeOrder(callee) + } + + for (callee in inliningOrder) { + val result = callee.owner.deepCopyWithSymbols(callee.owner.parent) + result.transformChildrenVoid(inlineTransformer) + // TODO: run erasure on result + cache[callee] = result // visible to future inlines in this loop + } + + container.transformChildrenVoid(inlineTransformer) + } + + override fun lower(irModule: IrModuleFragment) { + lower(irModule as IrElement) + } + + override fun lower(irFile: IrFile) { + lower(irFile as IrElement) + } + + fun lower(irFunction: IrFunction) { + lower(irFunction.body as IrElement) + for (parameter in irFunction.parameters) { + parameter.defaultValue?.let(::lower) + } + } +} + +class FunctionInliningTransformer( + val context: LoweringContext, + private val inlineFunctionResolver: InlineFunctionResolver, + private val insertAdditionalImplicitCasts: Boolean = true, + private val regenerateInlinedAnonymousObjects: Boolean = false, + private val produceOuterThisFields: Boolean = true, +) : IrElementTransformerVoidWithContext() { init { require(!produceOuterThisFields || context is CommonBackendContext) { "The inliner can generate outer fields only with param `context` of type `CommonBackendContext`" } } - override fun lower(irBody: IrBody, container: IrDeclaration) { - // TODO container: IrSymbolDeclaration - withinScope(container) { - irBody.accept(this, null) - } - - irBody.patchDeclarationParents(container as? IrDeclarationParent ?: container.parent) - } - override fun visitDeclaration(declaration: IrDeclarationBase): IrStatement { return when (declaration) { is IrFunction, is IrClass, is IrProperty -> context.irFactory.stageController.restrictTo(declaration) { @@ -60,8 +145,6 @@ } } - fun inline(irModule: IrModuleFragment) = irModule.accept(this, data = null) - override fun visitFunctionAccess(expression: IrFunctionAccessExpression): IrExpression { expression.transformChildrenVoid(this) @@ -82,16 +165,6 @@ return expression } - withinScope(actualCallee) { - actualCallee.body?.transformChildrenVoid() - actualCallee.parameters.forEachIndexed { index, param -> - if (expression.arguments[index] == null) { - // Default values can recursively reference [callee] - transform only needed. - param.defaultValue = param.defaultValue?.transform(this@FunctionInlining, null) - } - } - } - val parent = allScopes.map { it.irElement }.filterIsInstance<IrDeclarationParent>().lastOrNull() ?: allScopes.map { it.irElement }.filterIsInstance<IrDeclaration>().lastOrNull()?.parent
diff --git a/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/InlineFunctionResolver.kt b/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/InlineFunctionResolver.kt index 6a8be92..7bce198 100644 --- a/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/InlineFunctionResolver.kt +++ b/compiler/ir/ir.inline/src/org/jetbrains/kotlin/ir/inline/InlineFunctionResolver.kt
@@ -54,18 +54,27 @@ ALL_FUNCTIONS, } -abstract class InlineFunctionResolver(val inlineMode: InlineMode) { - open val callInlinerStrategy: CallInlinerStrategy +interface InlineFunctionResolver { + val inlineMode: InlineMode + val callInlinerStrategy: CallInlinerStrategy + val allowExternalInlining: Boolean + fun needsInlining(symbol: IrFunctionSymbol): Boolean + fun needsInlining(expression: IrFunctionAccessExpression): Boolean + fun getFunctionDeclaration(symbol: IrFunctionSymbol): IrFunction? +} + +abstract class AbstractInlineFunctionResolver(override val inlineMode: InlineMode) : InlineFunctionResolver { + override val callInlinerStrategy: CallInlinerStrategy get() = CallInlinerStrategy.DEFAULT - open val allowExternalInlining: Boolean + override val allowExternalInlining: Boolean get() = false - open fun needsInlining(symbol: IrFunctionSymbol) = + override fun needsInlining(symbol: IrFunctionSymbol) = symbol.isBound && symbol.owner.isInline && (allowExternalInlining || !symbol.owner.isExternal) - open fun needsInlining(expression: IrFunctionAccessExpression) = needsInlining(expression.symbol) + override fun needsInlining(expression: IrFunctionAccessExpression) = needsInlining(expression.symbol) - open fun getFunctionDeclaration(symbol: IrFunctionSymbol): IrFunction? { + override fun getFunctionDeclaration(symbol: IrFunctionSymbol): IrFunction? { if (shouldExcludeFunctionFromInlining(symbol)) return null val owner = symbol.owner @@ -80,7 +89,7 @@ abstract class InlineFunctionResolverReplacingCoroutineIntrinsics<Ctx : LoweringContext>( protected val context: Ctx, inlineMode: InlineMode, -) : InlineFunctionResolver(inlineMode) { +) : AbstractInlineFunctionResolver(inlineMode) { final override val allowExternalInlining: Boolean get() = context.allowExternalInlining
diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/NativeLoweringPhases.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/NativeLoweringPhases.kt index 0f59c41..633bc52 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/NativeLoweringPhases.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/NativeLoweringPhases.kt
@@ -363,11 +363,11 @@ /** * The second phase of inlining (inline all functions). */ -internal val inlineAllFunctionsPhase = createFileLoweringPhase( - lowering = { context: Context -> - NativeIrInliner(context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS) - }, +internal val inlineAllFunctionsPhase = createSimpleNamedCompilerPhase<NativeGenerationState, IrModuleFragment>( name = "InlineAllFunctions", + op = { context: NativeGenerationState, module -> + NativeIrInliner(context.context, inlineMode = InlineMode.ALL_INLINE_FUNCTIONS).lower(module) + } ) private val interopPhase = createFileLoweringPhase(
diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/TopLevelPhases.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/TopLevelPhases.kt index 434066c..a70d854 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/TopLevelPhases.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/driver/phases/TopLevelPhases.kt
@@ -160,7 +160,7 @@ // invariant, we would like to put a synchronization point immediately before "InlineAllFunctions". fragmentWithState.forEach { (fragment, state) -> state.runSpecifiedLowerings(fragment, getLoweringsUpToAndIncludingSyntheticAccessors()) } fragmentWithState.forEach { (fragment, state) -> state.runSpecifiedLowerings(fragment, validateIrAfterInliningOnlyPrivateFunctions) } - fragmentWithState.forEach { (fragment, state) -> state.runSpecifiedLowerings(fragment, listOf(inlineAllFunctionsPhase)) } + fragmentWithState.forEach { (fragment, state) -> state.runSpecifiedLowerings(fragment, inlineAllFunctionsPhase) } if (context.config.configuration[KlibConfigurationKeys.SYNTHETIC_ACCESSORS_DUMP_DIR] != null) { fragmentWithState.forEach { (fragment, state) -> state.runSpecifiedLowerings(fragment, dumpSyntheticAccessorsPhase) } }
diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/NativeInlineFunctionResolver.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/NativeInlineFunctionResolver.kt index 9e066a4..27ff0dca 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/NativeInlineFunctionResolver.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/lower/NativeInlineFunctionResolver.kt
@@ -61,7 +61,7 @@ ArrayConstructorLowering(context).lower(body, function) - NativeIrInliner(context, inlineMode = InlineMode.PRIVATE_INLINE_FUNCTIONS).lower(body, function) + NativeIrInliner(context, inlineMode = InlineMode.PRIVATE_INLINE_FUNCTIONS).lower(function) OuterThisInInlineFunctionsSpecialAccessorLowering(context).lowerWithoutAddingAccessorsToParents(function) SyntheticAccessorLowering(context).lowerWithoutAddingAccessorsToParents(function) }
diff --git a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/PreCodegenInliner.kt b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/PreCodegenInliner.kt index f6030e2..96b639a 100644 --- a/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/PreCodegenInliner.kt +++ b/kotlin-native/backend.native/compiler/ir/backend.native/src/org/jetbrains/kotlin/backend/konan/optimizations/PreCodegenInliner.kt
@@ -20,12 +20,15 @@ import org.jetbrains.kotlin.ir.expressions.IrCall import org.jetbrains.kotlin.ir.expressions.IrFunctionAccessExpression import org.jetbrains.kotlin.ir.expressions.IrSuspensionPoint +import org.jetbrains.kotlin.ir.inline.AbstractInlineFunctionResolver import org.jetbrains.kotlin.ir.inline.FunctionInlining +import org.jetbrains.kotlin.ir.inline.FunctionInliningTransformer import org.jetbrains.kotlin.ir.inline.InlineFunctionResolver import org.jetbrains.kotlin.ir.inline.InlineMode import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol import org.jetbrains.kotlin.ir.types.classOrNull import org.jetbrains.kotlin.ir.util.* +import org.jetbrains.kotlin.ir.visitors.transformChildrenVoid import org.jetbrains.kotlin.library.metadata.isCInteropLibrary import kotlin.collections.* @@ -146,9 +149,9 @@ } if (functionsToInline.isNotEmpty()) { - val inliner = FunctionInlining( + val inliner = FunctionInliningTransformer( context, - inlineFunctionResolver = object : InlineFunctionResolver(inlineMode = InlineMode.ALL_FUNCTIONS) { + inlineFunctionResolver = object : AbstractInlineFunctionResolver(inlineMode = InlineMode.ALL_FUNCTIONS) { override fun shouldExcludeFunctionFromInlining(symbol: IrFunctionSymbol) = symbol.owner !in functionsToInline @@ -157,7 +160,7 @@ } }, ) - inliner.lower(irBody, irFunction) + irFunction.transform(inliner, null) // KT-72336: This is not entirely correct since coroutinesLivenessAnalysisPhase could be turned off. LivenessAnalysis.run(irBody) { it is IrSuspensionPoint }