Effect analysis: smarter handling of singleton initializers
diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
index b08c748..13b4b55 100644
--- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
+++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/ir/SideEffects.kt
@@ -25,28 +25,47 @@
 import org.jetbrains.kotlin.name.FqName
 import org.jetbrains.kotlin.utils.addToStdlib.safeAs
 
-enum class SideEffects {
+sealed class SideEffects(val level: Int, val name: String) {
 
     /**
-     * Aka 'pure'. [READNONE] expressions can be reordered or eliminated.
+     * Aka 'pure'. [ReadNone] expressions can be reordered or eliminated.
      */
-    READNONE,
+    object ReadNone : SideEffects(0, "READNONE")
 
     /**
-     * May read the global state, may not alter it. Calls to [READONLY] functions cannot be reordered, but can be eliminated.
+     * May read the global state, may not alter it. Calls to [ReadOnly] functions cannot be reordered, but can be eliminated.
      */
-    READONLY,
+    object ReadOnly : SideEffects(1, "READONLY")
 
     /**
-     * A special kind of effect. Used to mark singleton initializers that are otherwise would be considered [READWRITE],
+     * A special kind of effect. Used to mark singleton initializers that would otherwise be considered [ReadWrite],
      * but don't have any effects except saving the instance to a global variable.
+     *
+     * [otherwise] holds the effects a function would have if all [AlmostPureSingletonConstructor] effects were removed from it; it is never [ReadWrite] (that would overflow [level]) nor another [AlmostPureSingletonConstructor].
      */
-    ALMOST_PURE_SINGLETON_CONSTRUCTOR,
+    data class AlmostPureSingletonConstructor(val otherwise: SideEffects) :
+        SideEffects(2 + otherwise.level, "ALMOST_PURE_SINGLETON_CONSTRUCTOR")
+    {
+        init {
+            require(otherwise !is AlmostPureSingletonConstructor && otherwise !is ReadWrite) { "'otherwise' cannot be ${otherwise.name}" }
+        }
+    }
 
     /**
      * Can arbitrarily alter the global state.
      */
-    READWRITE,
+    object ReadWrite : SideEffects(Int.MAX_VALUE, "READWRITE")
+
+    fun isAtMost(other: SideEffects) = level <= other.level
+
+    companion object {
+        fun valueOf(s: String) = when (s) {
+            ReadNone.name -> ReadNone
+            ReadOnly.name -> ReadOnly
+            ReadWrite.name -> ReadWrite
+            else -> throw IllegalArgumentException("$s is not a valid side effect name")
+        }
+    }
 }
 
 private val effectsAnnotationFqName = FqName("kotlin.internal.Effects")
@@ -61,8 +80,8 @@
 }
 
 fun IrFunction.addEffectsAnnotation(effects: SideEffects, context: CommonBackendContext) {
-    if (effects == SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR) {
-        error("${SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR.name} cannot be set via an annotation!")
+    if (effects is SideEffects.AlmostPureSingletonConstructor) {
+        error("${effects.name} cannot be set via an annotation!")
     }
     val annotationClassSymbol = context.getClassSymbol(effectsAnnotationFqName)
     val enumClassSymbol = context.getClassSymbol(effectEnumFqName)
@@ -83,7 +102,7 @@
     memoizer: FunctionSideEffectMemoizer = mutableMapOf(),
     context: CommonBackendContext? = null,
 ): SideEffects =
-    this?.accept(EffectAnalyzer(anyVariableReadIsPure, memoizer, context), Unit) ?: SideEffects.READNONE
+    this?.accept(EffectAnalyzer(anyVariableReadIsPure, memoizer, context), Unit) ?: SideEffects.ReadNone
 
 fun IrFunction.computeEffects(
     anyVariableReadIsPure: Boolean,
@@ -100,23 +119,50 @@
 ) = analyzer.memoizer.getOrPut(symbol) {
     if (analyzer.context != null && this is IrConstructor) {
         if (symbol.owner.constructedClass.symbol == analyzer.context.irBuiltIns.anyClass) {
-            return@getOrPut SideEffects.READNONE
+            return@getOrPut SideEffects.ReadNone
         }
     }
-    getDeclaredEffects()
+    val effects = getDeclaredEffects()
         ?: body?.accept(analyzer, Unit)
-        ?: SideEffects.READWRITE
+        ?: SideEffects.ReadWrite
+
+    // Only constructors keep the AlmostPureSingletonConstructor marker; other functions report just `otherwise`.
+    // TODO: Good heuristic for Kotlin/JS, but will it work on other backends?
+    if (effects is SideEffects.AlmostPureSingletonConstructor && this !is IrConstructor)
+        return@getOrPut effects.otherwise // a bare `return` would leave getOrPut without storing the result
+    effects
 }
 
-fun IrExpression?.isPure(anyVariableReadIsPure: Boolean) = computeEffects(anyVariableReadIsPure) == SideEffects.READNONE
+fun IrExpression?.isPure(anyVariableReadIsPure: Boolean) = computeEffects(anyVariableReadIsPure) == SideEffects.ReadNone
 
 private inline fun <T> Iterable<T>.maxEffect(computeEffects: (T) -> SideEffects): SideEffects {
-    return maxOfOrNull {
-        computeEffects(it).also { result ->
-            // Early exit to avoid expensive computations if we already know that we're going to get READWRITE.
-            if (result == SideEffects.READWRITE) return SideEffects.READWRITE
-        }
-    } ?: SideEffects.READNONE
+    val iterator = iterator()
+    if (!iterator.hasNext()) return SideEffects.ReadNone
+    var maxValue = computeEffects(iterator.next())
+    while (iterator.hasNext()) {
+        if (maxValue is SideEffects.ReadWrite) return maxValue // ReadWrite absorbs everything: skip the remaining computations
+        val v = computeEffects(iterator.next())
+        maxValue = maxEffectOf(maxValue, v)
+    }
+    return maxValue
+}
+
+/**
+ * Combines the effects of evaluating both [a] and [b].
+ *
+ * - [SideEffects.ReadWrite] dominates everything.
+ * - If either side is an [SideEffects.AlmostPureSingletonConstructor], so is the result, and its `otherwise`
+ *   is the larger of the two sides' non-singleton effects.
+ * - In all other cases the effect with the higher [SideEffects.level] wins ([a] on ties).
+ */
+private fun maxEffectOf(a: SideEffects, b: SideEffects): SideEffects = when {
+    a is SideEffects.ReadWrite || b is SideEffects.ReadWrite -> SideEffects.ReadWrite
+    a is SideEffects.AlmostPureSingletonConstructor && b is SideEffects.AlmostPureSingletonConstructor ->
+        SideEffects.AlmostPureSingletonConstructor(maxEffectOf(a.otherwise, b.otherwise))
+    a is SideEffects.AlmostPureSingletonConstructor ->
+        if (a.otherwise.level < b.level) SideEffects.AlmostPureSingletonConstructor(b) else a
+    b is SideEffects.AlmostPureSingletonConstructor -> maxEffectOf(b, a)
+    else -> if (a.level >= b.level) a else b
 }
 
 private class EffectAnalyzer(
@@ -132,32 +178,32 @@
     }
 
     override fun visitElement(element: IrElement, data: Unit): SideEffects {
-        return SideEffects.READWRITE
+        return SideEffects.ReadWrite
     }
 
     override fun visitExpression(expression: IrExpression, data: Unit): SideEffects {
-        return SideEffects.READWRITE
+        return SideEffects.ReadWrite
     }
 
     override fun visitFunction(declaration: IrFunction, data: Unit): SideEffects {
         // Function declarations themselves have no effects.
-        return SideEffects.READNONE
+        return SideEffects.ReadNone
     }
 
     override fun visitClass(declaration: IrClass, data: Unit): SideEffects {
-        return SideEffects.READNONE
+        return SideEffects.ReadNone
     }
 
     override fun visitTypeAlias(declaration: IrTypeAlias, data: Unit): SideEffects {
-        return SideEffects.READNONE
+        return SideEffects.ReadNone
     }
 
     override fun visitVariable(declaration: IrVariable, data: Unit): SideEffects {
-        return declaration.initializer?.accept(this, data) ?: SideEffects.READNONE
+        return declaration.initializer?.accept(this, data) ?: SideEffects.ReadNone
     }
 
     override fun visitBody(body: IrBody, data: Unit): SideEffects {
-        return SideEffects.READWRITE
+        return SideEffects.ReadWrite
     }
 
     override fun visitBlockBody(body: IrBlockBody, data: Unit): SideEffects {
@@ -170,9 +216,9 @@
 
     override fun visitSyntheticBody(body: IrSyntheticBody, data: Unit): SideEffects {
         return when (body.kind) {
-            IrSyntheticBodyKind.ENUM_VALUES -> SideEffects.READNONE
-            IrSyntheticBodyKind.ENUM_VALUEOF -> SideEffects.READNONE
-            IrSyntheticBodyKind.ENUM_ENTRIES -> SideEffects.READNONE
+            IrSyntheticBodyKind.ENUM_VALUES -> SideEffects.ReadNone
+            IrSyntheticBodyKind.ENUM_VALUEOF -> SideEffects.ReadNone
+            IrSyntheticBodyKind.ENUM_ENTRIES -> SideEffects.ReadNone
         }
     }
 
@@ -185,57 +231,57 @@
     }
 
     override fun visitConst(expression: IrConst<*>, data: Unit): SideEffects {
-        return SideEffects.READNONE
+        return SideEffects.ReadNone
     }
 
     override fun visitGetValue(expression: IrGetValue, data: Unit): SideEffects {
-        if (anyVariableReadIsPure) return SideEffects.READNONE
+        if (anyVariableReadIsPure) return SideEffects.ReadNone
         val valueDeclaration = expression.symbol.owner
         val isPure = if (valueDeclaration is IrVariable) !valueDeclaration.isVar
         else true
-        return if (isPure) SideEffects.READNONE else SideEffects.READWRITE
+        return if (isPure) SideEffects.ReadNone else SideEffects.ReadWrite
     }
 
     override fun visitTypeOperator(expression: IrTypeOperatorCall, data: Unit): SideEffects {
         return if (expression.operator !in setOf(IrTypeOperator.INSTANCEOF, IrTypeOperator.REINTERPRET_CAST, IrTypeOperator.NOT_INSTANCEOF))
-            SideEffects.READWRITE
+            SideEffects.ReadWrite
         else
             expression.argument.computeEffects(anyVariableReadIsPure)
     }
 
     override fun visitGetObjectValue(expression: IrGetObjectValue, data: Unit): SideEffects {
-        return expression.symbol.owner.primaryConstructor?.accept(this, data) ?: SideEffects.READNONE
+        return expression.symbol.owner.primaryConstructor?.accept(this, data) ?: SideEffects.ReadNone
     }
 
     override fun visitGetField(expression: IrGetField, data: Unit): SideEffects {
         if (!expression.symbol.owner.isFinal && !anyVariableReadIsPure) {
-            return SideEffects.READWRITE
+            return SideEffects.ReadWrite
         }
-        return expression.receiver?.accept(this, data) ?: SideEffects.READONLY
+        return expression.receiver?.accept(this, data) ?: SideEffects.ReadOnly
     }
 
     override fun visitSetField(expression: IrSetField, data: Unit): SideEffects {
         val valueEffect = expression.value.accept(this, data)
-        if (valueEffect == SideEffects.READWRITE) return SideEffects.READWRITE
+        if (valueEffect == SideEffects.ReadWrite) return SideEffects.ReadWrite
 
-        val constructorSymbol = callStack.lastOrNull() as? IrConstructorSymbol? ?: return SideEffects.READWRITE
+        val constructorSymbol = callStack.lastOrNull() as? IrConstructorSymbol? ?: return SideEffects.ReadWrite
 
         if (expression.symbol.owner.origin == IrDeclarationOrigin.FIELD_FOR_OBJECT_INSTANCE) {
-            return SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR
+            return valueEffect as? SideEffects.AlmostPureSingletonConstructor ?: SideEffects.AlmostPureSingletonConstructor(valueEffect)
         }
 
         // If we are in a constructor, and we're setting a constructed instance's field, treat it as a READNONE operation.
-        val receiver = expression.receiver as? IrGetValue ?: return SideEffects.READWRITE
-        val valueParameter = receiver.symbol.owner as? IrValueParameter ?: return SideEffects.READWRITE
-        if (!valueParameter.isDispatchReceiver) return SideEffects.READWRITE
+        val receiver = expression.receiver as? IrGetValue ?: return SideEffects.ReadWrite
+        val valueParameter = receiver.symbol.owner as? IrValueParameter ?: return SideEffects.ReadWrite
+        if (!valueParameter.isDispatchReceiver) return SideEffects.ReadWrite
         val assignmentEffect =
-            if (valueParameter.parent == constructorSymbol.owner.constructedClass) SideEffects.READNONE else SideEffects.READWRITE
-        return maxOf(valueEffect, assignmentEffect)
+            if (valueParameter.parent == constructorSymbol.owner.constructedClass) SideEffects.ReadNone else SideEffects.ReadWrite
+        return maxEffectOf(valueEffect, assignmentEffect)
     }
 
     override fun visitVararg(expression: IrVararg, data: Unit): SideEffects {
         return expression.elements.maxEffect {
-            (it as? IrExpression)?.accept(this, data) ?: SideEffects.READWRITE
+            (it as? IrExpression)?.accept(this, data) ?: SideEffects.ReadWrite
         }
     }
 
@@ -245,7 +291,7 @@
         if (callStack.contains(function.symbol)) {
             // Consider recursive calls non-pure
             // A more precise analysis can be done, but for now we stick with this.
-            return SideEffects.READWRITE
+            return SideEffects.ReadWrite
         }
 
         callStack.push(function.symbol)
@@ -253,13 +299,13 @@
         try {
             val functionSideEffects = function.computeEffectsImpl(this)
 
-            if (functionSideEffects == SideEffects.READWRITE) return SideEffects.READWRITE
+            if (functionSideEffects == SideEffects.ReadWrite) return SideEffects.ReadWrite
 
             val argComputationSideEffects = (0 until expression.valueArgumentsCount).maxEffect {
                 expression.getValueArgument(it)!!.accept(this, data)
             }
 
-            return maxOf(functionSideEffects, argComputationSideEffects)
+            return maxEffectOf(functionSideEffects, argComputationSideEffects)
         } finally {
             callStack.pop()
         }
@@ -270,11 +316,11 @@
     }
 
     override fun visitBreakContinue(jump: IrBreakContinue, data: Unit): SideEffects {
-        return SideEffects.READNONE
+        return SideEffects.ReadNone
     }
 
     override fun visitFunctionExpression(expression: IrFunctionExpression, data: Unit): SideEffects {
-        return SideEffects.READNONE
+        return SideEffects.ReadNone
     }
 
     override fun visitWhen(expression: IrWhen, data: Unit): SideEffects {
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
index 91462c0..45de2b2 100644
--- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
+++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/JsCodeOutliningLowering.kt
@@ -164,7 +164,7 @@
                 type = local.type
             }
         }
-        outlinedFunction.addEffectsAnnotation(SideEffects.READWRITE, backendContext)
+        outlinedFunction.addEffectsAnnotation(SideEffects.ReadWrite, backendContext)
 
         // Building JS Ast function
         val lastStatement = jsStatements.findLast { it !is JsSingleLineComment && it !is JsMultiLineComment }
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
index c8af1d2..a21de50 100644
--- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
+++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/cleanup/CleanupLowering.kt
@@ -11,7 +11,6 @@
 import org.jetbrains.kotlin.backend.common.ir.computeEffects
 import org.jetbrains.kotlin.ir.IrElement
 import org.jetbrains.kotlin.ir.IrStatement
-import org.jetbrains.kotlin.ir.backend.js.JsIrBackendContext
 import org.jetbrains.kotlin.ir.declarations.IrDeclaration
 import org.jetbrains.kotlin.ir.expressions.*
 import org.jetbrains.kotlin.ir.symbols.IrFunctionSymbol
@@ -70,13 +69,10 @@
         var unreachable = false
         for (statement in statements) {
             if (statement is IrFunctionAccessExpression) {
-                val functionSideEffect = statement.symbol.owner.computeEffects(true, functionSideEffectMemoizer, context).let {
-                    if (it == SideEffects.ALMOST_PURE_SINGLETON_CONSTRUCTOR && statement !is IrConstructorCall)
-                        SideEffects.READNONE
-                    else it
-                }
+                val functionSideEffect = statement.symbol.owner.computeEffects(true, functionSideEffectMemoizer, context)
 
-                if (functionSideEffect <= SideEffects.READONLY) {
+                if (functionSideEffect.isAtMost(SideEffects.ReadOnly)) {
+                    // Drop the call itself but keep its arguments: evaluating them may still have side effects.
                     for (i in 0 until statement.valueArgumentsCount) {
                         add(statement.getValueArgument(i)!!)
                     }
@@ -89,7 +85,7 @@
                 statement is IrExpression && (statement.computeEffects(
                     true,
                     functionSideEffectMemoizer
-                ) <= SideEffects.READONLY) -> false
+                ).isAtMost(SideEffects.ReadOnly)) -> false
                 unreachable -> false
                 else -> {
                     unreachable = statement.doesNotReturn()