Effect analysis: smarter handling of singleton initializers
diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt index b08c748..13b4b55 100644 --- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt +++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
@@ -25,28 +25,47 @@ import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.utils.addToStdlib.safeAs -enum class SideEffects { +sealed class SideEffects(val level: Int, val name: String) { /** - * Aka 'pure'. [READNONE] expressions can be reordered or eliminated. + * Aka 'pure'. [ReadNone] expressions can be reordered or eliminated. */ - READNONE, + object ReadNone : SideEffects(0, "READNONE") /** - * May read the global state, may not alter it. Calls to [READONLY] functions cannot be reordered, but can be eliminated. + * May read the global state, may not alter it. Calls to [ReadOnly] functions cannot be reordered, but can be eliminated. */ - READONLY, + object ReadOnly : SideEffects(1, "READONLY") /** - * A special kind of effect. Used to mark singleton initializers that are otherwise would be considered [READWRITE], + * A special kind of effect. Used to mark singleton initializers that otherwise would be considered [ReadWrite], * but don't have any effects except saving the instance to a global variable. + * + * [otherwise] is the effect the analyzed code would have if all of its [AlmostPureSingletonConstructor] effects were removed; it is never itself an [AlmostPureSingletonConstructor]. */ - ALMOST_PURE_SINGLETON_CONSTRUCTOR, + data class AlmostPureSingletonConstructor(val otherwise: SideEffects) : + SideEffects(2 + otherwise.level, "ALMOST_PURE_SINGLETON_CONSTRUCTOR") + { + init { + require(otherwise !is AlmostPureSingletonConstructor) { "AlmostPureSingletonConstructor effects must not be nested" } + } + } /** * Can arbitrarily alter the global state.
*/ - READWRITE, + object ReadWrite : SideEffects(Int.MAX_VALUE, "READWRITE") + + fun isAtMost(other: SideEffects) = level <= other.level + + companion object { + fun valueOf(s: String) = when (s) { + ReadNone.name -> ReadNone + ReadOnly.name -> ReadOnly + ReadWrite.name -> ReadWrite + else -> throw IllegalArgumentException("$s is not a valid side effect name") + } + } } private val effectsAnnotationFqName = FqName("kotlin.internal.Effects") @@ -61,8 +80,8 @@ } fun IrFunction.addEffectsAnnotation(effects: SideEffects, context: CommonBackendContext) { - if (effects == SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR) { - error("${SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR.name} cannot be set via an annotation!") + if (effects is SideEffects.AlmostPureSingletonConstructor) { + error("${effects.name} cannot be set via an annotation!") } val annotationClassSymbol = context.getClassSymbol(effectsAnnotationFqName) val enumClassSymbol = context.getClassSymbol(effectEnumFqName) @@ -83,7 +102,7 @@ memoizer: FunctionSideEffectMemoizer = mutableMapOf(), context: CommonBackendContext? = null, ): SideEffects = - this?.accept(EffectAnalyzer(anyVariableReadIsPure, memoizer, context), Unit) ?: SideEffects.READNONE + this?.accept(EffectAnalyzer(anyVariableReadIsPure, memoizer, context), Unit) ?: SideEffects.ReadNone fun IrFunction.computeEffects( anyVariableReadIsPure: Boolean, @@ -100,23 +119,50 @@ ) = analyzer.memoizer.getOrPut(symbol) { if (analyzer.context != null && this is IrConstructor) { if (symbol.owner.constructedClass.symbol == analyzer.context.irBuiltIns.anyClass) { - return@getOrPut SideEffects.READNONE + return@getOrPut SideEffects.ReadNone } } - getDeclaredEffects() + val effects = getDeclaredEffects() ?: body?.accept(analyzer, Unit) - ?: SideEffects.READWRITE + ?: SideEffects.ReadWrite + + // TODO: Good heuristic for Kotlin/JS, but will it work on other backends?
+ if (effects is SideEffects.AlmostPureSingletonConstructor && this !is IrConstructor) + return@getOrPut effects.otherwise + + effects } -fun IrExpression?.isPure(anyVariableReadIsPure: Boolean) = computeEffects(anyVariableReadIsPure) == SideEffects.READNONE +fun IrExpression?.isPure(anyVariableReadIsPure: Boolean) = computeEffects(anyVariableReadIsPure) == SideEffects.ReadNone private inline fun <T> Iterable<T>.maxEffect(computeEffects: (T) -> SideEffects): SideEffects { - return maxOfOrNull { - computeEffects(it).also { result -> - // Early exit to avoid expensive computations if we already know that we're going to get READWRITE. - if (result == SideEffects.READWRITE) return SideEffects.READWRITE - } - } ?: SideEffects.READNONE + val iterator = iterator() + if (!iterator.hasNext()) return SideEffects.ReadNone + var maxValue = computeEffects(iterator.next()) + while (iterator.hasNext()) { + if (maxValue is SideEffects.ReadWrite) return maxValue // Early exit: nothing exceeds ReadWrite, so skip analyzing the remaining elements. + val v = computeEffects(iterator.next()) + maxValue = maxEffectOf(maxValue, v) + } + return maxValue +} + +private fun maxEffectOf(a: SideEffects, b: SideEffects): SideEffects { + if (a is SideEffects.ReadWrite || b is SideEffects.ReadWrite) return SideEffects.ReadWrite + + if (a is SideEffects.AlmostPureSingletonConstructor && b is SideEffects.AlmostPureSingletonConstructor) { + return SideEffects.AlmostPureSingletonConstructor(maxEffectOf(a.otherwise, b.otherwise)) + } + + if (a is SideEffects.AlmostPureSingletonConstructor) { + return if (a.otherwise.level < b.level) SideEffects.AlmostPureSingletonConstructor(b) else a + } + + if (b is SideEffects.AlmostPureSingletonConstructor) { + return maxEffectOf(b, a) + } + + return if (a.level >= b.level) a else b } private class EffectAnalyzer( @@ -132,32 +178,32 @@ } override fun visitElement(element: IrElement, data: Unit): SideEffects { - return SideEffects.READWRITE + return SideEffects.ReadWrite } override fun visitExpression(expression: IrExpression, data: Unit): SideEffects { -
return SideEffects.READWRITE + return SideEffects.ReadWrite } override fun visitFunction(declaration: IrFunction, data: Unit): SideEffects { // Function declarations themselves have no effects. - return SideEffects.READNONE + return SideEffects.ReadNone } override fun visitClass(declaration: IrClass, data: Unit): SideEffects { - return SideEffects.READNONE + return SideEffects.ReadNone } override fun visitTypeAlias(declaration: IrTypeAlias, data: Unit): SideEffects { - return SideEffects.READNONE + return SideEffects.ReadNone } override fun visitVariable(declaration: IrVariable, data: Unit): SideEffects { - return declaration.initializer?.accept(this, data) ?: SideEffects.READNONE + return declaration.initializer?.accept(this, data) ?: SideEffects.ReadNone } override fun visitBody(body: IrBody, data: Unit): SideEffects { - return SideEffects.READWRITE + return SideEffects.ReadWrite } override fun visitBlockBody(body: IrBlockBody, data: Unit): SideEffects { @@ -170,9 +216,9 @@ override fun visitSyntheticBody(body: IrSyntheticBody, data: Unit): SideEffects { return when (body.kind) { - IrSyntheticBodyKind.ENUM_VALUES -> SideEffects.READNONE - IrSyntheticBodyKind.ENUM_VALUEOF -> SideEffects.READNONE - IrSyntheticBodyKind.ENUM_ENTRIES -> SideEffects.READNONE + IrSyntheticBodyKind.ENUM_VALUES -> SideEffects.ReadNone + IrSyntheticBodyKind.ENUM_VALUEOF -> SideEffects.ReadNone + IrSyntheticBodyKind.ENUM_ENTRIES -> SideEffects.ReadNone } } @@ -185,57 +231,57 @@ } override fun visitConst(expression: IrConst<*>, data: Unit): SideEffects { - return SideEffects.READNONE + return SideEffects.ReadNone } override fun visitGetValue(expression: IrGetValue, data: Unit): SideEffects { - if (anyVariableReadIsPure) return SideEffects.READNONE + if (anyVariableReadIsPure) return SideEffects.ReadNone val valueDeclaration = expression.symbol.owner val isPure = if (valueDeclaration is IrVariable) !valueDeclaration.isVar else true - return if (isPure) SideEffects.READNONE else
SideEffects.READWRITE + return if (isPure) SideEffects.ReadNone else SideEffects.ReadWrite } override fun visitTypeOperator(expression: IrTypeOperatorCall, data: Unit): SideEffects { return if (expression.operator !in setOf(IrTypeOperator.INSTANCEOF, IrTypeOperator.REINTERPRET_CAST, IrTypeOperator.NOT_INSTANCEOF)) - SideEffects.READWRITE + SideEffects.ReadWrite else expression.argument.computeEffects(anyVariableReadIsPure) } override fun visitGetObjectValue(expression: IrGetObjectValue, data: Unit): SideEffects { - return expression.symbol.owner.primaryConstructor?.accept(this, data) ?: SideEffects.READNONE + return expression.symbol.owner.primaryConstructor?.accept(this, data) ?: SideEffects.ReadNone } override fun visitGetField(expression: IrGetField, data: Unit): SideEffects { if (!expression.symbol.owner.isFinal && !anyVariableReadIsPure) { - return SideEffects.READWRITE + return SideEffects.ReadWrite } - return expression.receiver?.accept(this, data) ?: SideEffects.READONLY + return expression.receiver?.accept(this, data) ?: SideEffects.ReadOnly } override fun visitSetField(expression: IrSetField, data: Unit): SideEffects { val valueEffect = expression.value.accept(this, data) - if (valueEffect == SideEffects.READWRITE) return SideEffects.READWRITE + if (valueEffect == SideEffects.ReadWrite) return SideEffects.ReadWrite - val constructorSymbol = callStack.lastOrNull() as? IrConstructorSymbol? ?: return SideEffects.READWRITE + val constructorSymbol = callStack.lastOrNull() as? IrConstructorSymbol? ?: return SideEffects.ReadWrite if (expression.symbol.owner.origin == IrDeclarationOrigin.FIELD_FOR_OBJECT_INSTANCE) { - return SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR + return if (valueEffect is SideEffects.AlmostPureSingletonConstructor) valueEffect else SideEffects.AlmostPureSingletonConstructor(valueEffect) } // If we are in a constructor, and we're setting a constructed instance's field, treat it as a READNONE operation. - val receiver = expression.receiver as?
IrGetValue ?: return SideEffects.READWRITE - val valueParameter = receiver.symbol.owner as? IrValueParameter ?: return SideEffects.READWRITE - if (!valueParameter.isDispatchReceiver) return SideEffects.READWRITE + val receiver = expression.receiver as? IrGetValue ?: return SideEffects.ReadWrite + val valueParameter = receiver.symbol.owner as? IrValueParameter ?: return SideEffects.ReadWrite + if (!valueParameter.isDispatchReceiver) return SideEffects.ReadWrite val assignmentEffect = - if (valueParameter.parent == constructorSymbol.owner.constructedClass) SideEffects.READNONE else SideEffects.READWRITE - return maxOf(valueEffect, assignmentEffect) + if (valueParameter.parent == constructorSymbol.owner.constructedClass) SideEffects.ReadNone else SideEffects.ReadWrite + return maxEffectOf(valueEffect, assignmentEffect) } override fun visitVararg(expression: IrVararg, data: Unit): SideEffects { return expression.elements.maxEffect { - (it as? IrExpression)?.accept(this, data) ?: SideEffects.READWRITE + (it as? IrExpression)?.accept(this, data) ?: SideEffects.ReadWrite } } @@ -245,7 +291,7 @@ if (callStack.contains(function.symbol)) { // Consider recursive calls non-pure // A more precise analysis can be done, but for now we stick with this.
- return SideEffects.READWRITE + return SideEffects.ReadWrite } callStack.push(function.symbol) @@ -253,13 +299,13 @@ try { val functionSideEffects = function.computeEffectsImpl(this) - if (functionSideEffects == SideEffects.READWRITE) return SideEffects.READWRITE + if (functionSideEffects == SideEffects.ReadWrite) return SideEffects.ReadWrite val argComputationSideEffects = (0 until expression.valueArgumentsCount).maxEffect { expression.getValueArgument(it)!!.accept(this, data) } - return maxOf(functionSideEffects, argComputationSideEffects) + return maxEffectOf(functionSideEffects, argComputationSideEffects) } finally { callStack.pop() } @@ -270,11 +316,11 @@ } override fun visitBreakContinue(jump: IrBreakContinue, data: Unit): SideEffects { - return SideEffects.READNONE + return SideEffects.ReadNone } override fun visitFunctionExpression(expression: IrFunctionExpression, data: Unit): SideEffects { - return SideEffects.READNONE + return SideEffects.ReadNone } override fun visitWhen(expression: IrWhen, data: Unit): SideEffects {
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt index 91462c0..45de2b2 100644 --- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt +++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
@@ -164,7 +164,7 @@ type = local.type } } - outlinedFunction.addEffectsAnnotation(SideEffects.READWRITE, backendContext) + outlinedFunction.addEffectsAnnotation(SideEffects.ReadWrite, backendContext) // Building JS Ast function val lastStatement = jsStatements.findLast { it !is JsSingleLineComment && it !is JsMultiLineComment }
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt index c8af1d2..a21de50 100644 --- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt +++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
@@ -11,7 +11,6 @@ import org.jetbrains.kotlin.backend.common.ir.computeEffects import org.jetbrains.kotlin.ir.IrElement import org.jetbrains.kotlin.ir.IrStatement -import org.jetbrains.kotlin.ir.backend.js.JsIrBackendContext import org.jetbrains.kotlin.ir.declarations.IrDeclaration import org.jetbrains.kotlin.ir.expressions.* import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol @@ -70,13 +69,10 @@ var unreachable = false for (statement in statements) { if (statement is IrFunctionAccessExpression) { - val functionSideEffect = statement.symbol.owner.computeEffects(true, functionSideEffectMemoizer, context).let { - if (it == SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR && statement !is IrConstructorCall) - SideEffects.READNONE - else it - } + val functionSideEffect = statement.symbol.owner.computeEffects(true, functionSideEffectMemoizer, context) - if (functionSideEffect <= SideEffects.READONLY) { + if (functionSideEffect.isAtMost(SideEffects.ReadOnly)) { + // Eliminate the call itself but keep its arguments: evaluating them may still have effects. for (i in 0 until statement.valueArgumentsCount) { add(statement.getValueArgument(i)!!) } @@ -89,7 +85,7 @@ statement is IrExpression && (statement.computeEffects( true, functionSideEffectMemoizer - ) <= SideEffects.READONLY) -> false + ).isAtMost(SideEffects.ReadOnly)) -> false unreachable -> false else -> { unreachable = statement.doesNotReturn()