Effect analysis: a more clever handling of singleton initializers
diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
index b08c748..13b4b55 100644
--- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
+++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
@@ -25,28 +25,47 @@
import org.jetbrains.kotlin.name.FqName
import org.jetbrains.kotlin.utils.addToStdlib.safeAs
-enum class SideEffects {
+sealed class SideEffects(val level: Int, val name: String) {
/**
- * Aka 'pure'. [READNONE] expressions can be reordered or eliminated.
+ * Aka 'pure'. [ReadNone] expressions can be reordered or eliminated.
*/
- READNONE,
+ object ReadNone : SideEffects(0, "READNONE")
/**
- * May read the global state, may not alter it. Calls to [READONLY] functions cannot be reordered, but can be eliminated.
+ * May read the global state, may not alter it. Calls to [ReadOnly] functions cannot be reordered, but can be eliminated.
*/
- READONLY,
+ object ReadOnly : SideEffects(1, "READONLY")
/**
- * A special kind of effect. Used to mark singleton initializers that are otherwise would be considered [READWRITE],
+ * A special kind of effect. Used to mark singleton initializers that would otherwise be considered [ReadWrite],
* but don't have any effects except saving the instance to a global variable.
+ *
+ * [otherwise] describes the effects that a function would have if we removed all the [AlmostPureSingletonConstructor] effects from it.
*/
- ALMOST_PURE_SINGLETON_CONSTRUCTOR,
+ data class AlmostPureSingletonConstructor(val otherwise: SideEffects) :
+ SideEffects(if (otherwise.level > Int.MAX_VALUE - 2) Int.MAX_VALUE else 2 + otherwise.level, "ALMOST_PURE_SINGLETON_CONSTRUCTOR") // saturate: 2 + Int.MAX_VALUE would wrap negative and pass isAtMost(ReadOnly)
+ {
+ init {
+ require(otherwise !is AlmostPureSingletonConstructor)
+ }
+ }
/**
* Can arbitrarily alter the global state.
*/
- READWRITE,
+ object ReadWrite : SideEffects(Int.MAX_VALUE, "READWRITE")
+
+ fun isAtMost(other: SideEffects) = level <= other.level
+
+ companion object {
+ fun valueOf(s: String) = when (s) {
+ ReadNone.name -> ReadNone
+ ReadOnly.name -> ReadOnly
+ ReadWrite.name -> ReadWrite
+ else -> throw IllegalArgumentException("$s is not a valid side effect name")
+ }
+ }
}
private val effectsAnnotationFqName = FqName("kotlin.internal.Effects")
@@ -61,8 +80,8 @@
}
fun IrFunction.addEffectsAnnotation(effects: SideEffects, context: CommonBackendContext) {
- if (effects == SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR) {
- error("${SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR.name} cannot be set via an annotation!")
+ if (effects is SideEffects.AlmostPureSingletonConstructor) {
+ error("${effects.name} cannot be set via an annotation!")
}
val annotationClassSymbol = context.getClassSymbol(effectsAnnotationFqName)
val enumClassSymbol = context.getClassSymbol(effectEnumFqName)
@@ -83,7 +102,7 @@
memoizer: FunctionSideEffectMemoizer = mutableMapOf(),
context: CommonBackendContext? = null,
): SideEffects =
- this?.accept(EffectAnalyzer(anyVariableReadIsPure, memoizer, context), Unit) ?: SideEffects.READNONE
+ this?.accept(EffectAnalyzer(anyVariableReadIsPure, memoizer, context), Unit) ?: SideEffects.ReadNone
fun IrFunction.computeEffects(
anyVariableReadIsPure: Boolean,
@@ -100,23 +119,50 @@
) = analyzer.memoizer.getOrPut(symbol) {
if (analyzer.context != null && this is IrConstructor) {
if (symbol.owner.constructedClass.symbol == analyzer.context.irBuiltIns.anyClass) {
- return@getOrPut SideEffects.READNONE
+ return@getOrPut SideEffects.ReadNone
}
}
- getDeclaredEffects()
+ val effects = getDeclaredEffects()
?: body?.accept(analyzer, Unit)
- ?: SideEffects.READWRITE
+ ?: SideEffects.ReadWrite
+
+ // TODO: Good heuristic for Kotlin/JS, but will it work on other backends?
+ if (effects is SideEffects.AlmostPureSingletonConstructor && this !is IrConstructor)
+ return@getOrPut effects.otherwise // labeled return: a bare `return` would exit the enclosing function and skip the memoizer
+
+ effects
}
-fun IrExpression?.isPure(anyVariableReadIsPure: Boolean) = computeEffects(anyVariableReadIsPure) == SideEffects.READNONE
+fun IrExpression?.isPure(anyVariableReadIsPure: Boolean) = computeEffects(anyVariableReadIsPure) == SideEffects.ReadNone
private inline fun <T> Iterable<T>.maxEffect(computeEffects: (T) -> SideEffects): SideEffects {
- return maxOfOrNull {
- computeEffects(it).also { result ->
- // Early exit to avoid expensive computations if we already know that we're going to get READWRITE.
- if (result == SideEffects.READWRITE) return SideEffects.READWRITE
- }
- } ?: SideEffects.READNONE
+ val iterator = iterator()
+ if (!iterator.hasNext()) return SideEffects.ReadNone
+ var maxValue = computeEffects(iterator.next())
+ while (iterator.hasNext()) {
+ if (maxValue is SideEffects.ReadWrite) return maxValue
+ val v = computeEffects(iterator.next())
+ maxValue = maxEffectOf(maxValue, v)
+ }
+ return maxValue
+}
+
+private fun maxEffectOf(a: SideEffects, b: SideEffects): SideEffects {
+ if (a is SideEffects.ReadWrite || b is SideEffects.ReadWrite) return SideEffects.ReadWrite
+
+ if (a is SideEffects.AlmostPureSingletonConstructor && b is SideEffects.AlmostPureSingletonConstructor) {
+ return SideEffects.AlmostPureSingletonConstructor(maxEffectOf(a.otherwise, b.otherwise))
+ }
+
+ if (a is SideEffects.AlmostPureSingletonConstructor) {
+ return if (a.otherwise.level < b.level) SideEffects.AlmostPureSingletonConstructor(b) else a
+ }
+
+ if (b is SideEffects.AlmostPureSingletonConstructor) {
+ return maxEffectOf(b, a)
+ }
+
+ return if (a.level >= b.level) a else b
}
private class EffectAnalyzer(
@@ -132,32 +178,32 @@
}
override fun visitElement(element: IrElement, data: Unit): SideEffects {
- return SideEffects.READWRITE
+ return SideEffects.ReadWrite
}
override fun visitExpression(expression: IrExpression, data: Unit): SideEffects {
- return SideEffects.READWRITE
+ return SideEffects.ReadWrite
}
override fun visitFunction(declaration: IrFunction, data: Unit): SideEffects {
// Function declarations themselves have no effects.
- return SideEffects.READNONE
+ return SideEffects.ReadNone
}
override fun visitClass(declaration: IrClass, data: Unit): SideEffects {
- return SideEffects.READNONE
+ return SideEffects.ReadNone
}
override fun visitTypeAlias(declaration: IrTypeAlias, data: Unit): SideEffects {
- return SideEffects.READNONE
+ return SideEffects.ReadNone
}
override fun visitVariable(declaration: IrVariable, data: Unit): SideEffects {
- return declaration.initializer?.accept(this, data) ?: SideEffects.READNONE
+ return declaration.initializer?.accept(this, data) ?: SideEffects.ReadNone
}
override fun visitBody(body: IrBody, data: Unit): SideEffects {
- return SideEffects.READWRITE
+ return SideEffects.ReadWrite
}
override fun visitBlockBody(body: IrBlockBody, data: Unit): SideEffects {
@@ -170,9 +216,9 @@
override fun visitSyntheticBody(body: IrSyntheticBody, data: Unit): SideEffects {
return when (body.kind) {
- IrSyntheticBodyKind.ENUM_VALUES -> SideEffects.READNONE
- IrSyntheticBodyKind.ENUM_VALUEOF -> SideEffects.READNONE
- IrSyntheticBodyKind.ENUM_ENTRIES -> SideEffects.READNONE
+ IrSyntheticBodyKind.ENUM_VALUES -> SideEffects.ReadNone
+ IrSyntheticBodyKind.ENUM_VALUEOF -> SideEffects.ReadNone
+ IrSyntheticBodyKind.ENUM_ENTRIES -> SideEffects.ReadNone
}
}
@@ -185,57 +231,57 @@
}
override fun visitConst(expression: IrConst<*>, data: Unit): SideEffects {
- return SideEffects.READNONE
+ return SideEffects.ReadNone
}
override fun visitGetValue(expression: IrGetValue, data: Unit): SideEffects {
- if (anyVariableReadIsPure) return SideEffects.READNONE
+ if (anyVariableReadIsPure) return SideEffects.ReadNone
val valueDeclaration = expression.symbol.owner
val isPure = if (valueDeclaration is IrVariable) !valueDeclaration.isVar
else true
- return if (isPure) SideEffects.READNONE else SideEffects.READWRITE
+ return if (isPure) SideEffects.ReadNone else SideEffects.ReadWrite
}
override fun visitTypeOperator(expression: IrTypeOperatorCall, data: Unit): SideEffects {
return if (expression.operator !in setOf(IrTypeOperator.INSTANCEOF, IrTypeOperator.REINTERPRET_CAST, IrTypeOperator.NOT_INSTANCEOF))
- SideEffects.READWRITE
+ SideEffects.ReadWrite
else
expression.argument.computeEffects(anyVariableReadIsPure)
}
override fun visitGetObjectValue(expression: IrGetObjectValue, data: Unit): SideEffects {
- return expression.symbol.owner.primaryConstructor?.accept(this, data) ?: SideEffects.READNONE
+ return expression.symbol.owner.primaryConstructor?.accept(this, data) ?: SideEffects.ReadNone
}
override fun visitGetField(expression: IrGetField, data: Unit): SideEffects {
if (!expression.symbol.owner.isFinal && !anyVariableReadIsPure) {
- return SideEffects.READWRITE
+ return SideEffects.ReadWrite
}
- return expression.receiver?.accept(this, data) ?: SideEffects.READONLY
+ return expression.receiver?.accept(this, data) ?: SideEffects.ReadOnly
}
override fun visitSetField(expression: IrSetField, data: Unit): SideEffects {
val valueEffect = expression.value.accept(this, data)
- if (valueEffect == SideEffects.READWRITE) return SideEffects.READWRITE
+ if (valueEffect == SideEffects.ReadWrite) return SideEffects.ReadWrite
- val constructorSymbol = callStack.lastOrNull() as? IrConstructorSymbol? ?: return SideEffects.READWRITE
+ val constructorSymbol = callStack.lastOrNull() as? IrConstructorSymbol? ?: return SideEffects.ReadWrite
if (expression.symbol.owner.origin == IrDeclarationOrigin.FIELD_FOR_OBJECT_INSTANCE) {
- return SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR
+ return SideEffects.AlmostPureSingletonConstructor(valueEffect)
}
// If we are in a constructor, and we're setting a constructed instance's field, treat it as a READNONE operation.
- val receiver = expression.receiver as? IrGetValue ?: return SideEffects.READWRITE
- val valueParameter = receiver.symbol.owner as? IrValueParameter ?: return SideEffects.READWRITE
- if (!valueParameter.isDispatchReceiver) return SideEffects.READWRITE
+ val receiver = expression.receiver as? IrGetValue ?: return SideEffects.ReadWrite
+ val valueParameter = receiver.symbol.owner as? IrValueParameter ?: return SideEffects.ReadWrite
+ if (!valueParameter.isDispatchReceiver) return SideEffects.ReadWrite
val assignmentEffect =
- if (valueParameter.parent == constructorSymbol.owner.constructedClass) SideEffects.READNONE else SideEffects.READWRITE
- return maxOf(valueEffect, assignmentEffect)
+ if (valueParameter.parent == constructorSymbol.owner.constructedClass) SideEffects.ReadNone else SideEffects.ReadWrite
+ return maxEffectOf(valueEffect, assignmentEffect)
}
override fun visitVararg(expression: IrVararg, data: Unit): SideEffects {
return expression.elements.maxEffect {
- (it as? IrExpression)?.accept(this, data) ?: SideEffects.READWRITE
+ (it as? IrExpression)?.accept(this, data) ?: SideEffects.ReadWrite
}
}
@@ -245,7 +291,7 @@
if (callStack.contains(function.symbol)) {
// Consider recursive calls non-pure
// A more precise analysis can be done, but for now we stick with this.
- return SideEffects.READWRITE
+ return SideEffects.ReadWrite
}
callStack.push(function.symbol)
@@ -253,13 +299,13 @@
try {
val functionSideEffects = function.computeEffectsImpl(this)
- if (functionSideEffects == SideEffects.READWRITE) return SideEffects.READWRITE
+ if (functionSideEffects == SideEffects.ReadWrite) return SideEffects.ReadWrite
val argComputationSideEffects = (0 until expression.valueArgumentsCount).maxEffect {
expression.getValueArgument(it)!!.accept(this, data)
}
- return maxOf(functionSideEffects, argComputationSideEffects)
+ return maxEffectOf(functionSideEffects, argComputationSideEffects)
} finally {
callStack.pop()
}
@@ -270,11 +316,11 @@
}
override fun visitBreakContinue(jump: IrBreakContinue, data: Unit): SideEffects {
- return SideEffects.READNONE
+ return SideEffects.ReadNone
}
override fun visitFunctionExpression(expression: IrFunctionExpression, data: Unit): SideEffects {
- return SideEffects.READNONE
+ return SideEffects.ReadNone
}
override fun visitWhen(expression: IrWhen, data: Unit): SideEffects {
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
index 91462c0..45de2b2 100644
--- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
+++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
@@ -164,7 +164,7 @@
type = local.type
}
}
- outlinedFunction.addEffectsAnnotation(SideEffects.READWRITE, backendContext)
+ outlinedFunction.addEffectsAnnotation(SideEffects.ReadWrite, backendContext)
// Building JS Ast function
val lastStatement = jsStatements.findLast { it !is JsSingleLineComment && it !is JsMultiLineComment }
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
index c8af1d2..a21de50 100644
--- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
+++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
@@ -11,7 +11,6 @@
import org.jetbrains.kotlin.backend.common.ir.computeEffects
import org.jetbrains.kotlin.ir.IrElement
import org.jetbrains.kotlin.ir.IrStatement
-import org.jetbrains.kotlin.ir.backend.js.JsIrBackendContext
import org.jetbrains.kotlin.ir.declarations.IrDeclaration
import org.jetbrains.kotlin.ir.expressions.*
import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
@@ -70,13 +69,10 @@
var unreachable = false
for (statement in statements) {
if (statement is IrFunctionAccessExpression) {
- val functionSideEffect = statement.symbol.owner.computeEffects(true, functionSideEffectMemoizer, context).let {
- if (it == SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR && statement !is IrConstructorCall)
- SideEffects.READNONE
- else it
- }
+ val functionSideEffect = statement.symbol.owner.computeEffects(true, functionSideEffectMemoizer, context)
- if (functionSideEffect <= SideEffects.READONLY) {
+ if (functionSideEffect.isAtMost(SideEffects.ReadOnly)) {
+ // Eliminate the call but keep the arguments, they can have effects.
for (i in 0 until statement.valueArgumentsCount) {
add(statement.getValueArgument(i)!!)
}
@@ -89,7 +85,7 @@
statement is IrExpression && (statement.computeEffects(
true,
functionSideEffectMemoizer
- ) <= SideEffects.READONLY) -> false
+ ).isAtMost(SideEffects.ReadOnly)) -> false
unreachable -> false
else -> {
unreachable = statement.doesNotReturn()