JS: do not generate a break for a top-level return during inlining
diff --git a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/FunctionInlineMutator.kt b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/FunctionInlineMutator.kt index 2a14bb1..d82c548 100644 --- a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/FunctionInlineMutator.kt +++ b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/FunctionInlineMutator.kt
@@ -123,7 +123,7 @@ val returnOnTop = ContainerUtil.findInstance(body.getStatements(), javaClass<JsReturn>()) val hasReturnOnTopLevel = returnOnTop != null - val needBreakLabel = !(returnCount == 1 && hasReturnOnTopLevel) + val needBreakLabel = !(returnCount == 0 || returnCount == 1 && hasReturnOnTopLevel) /* parsed as 0 || (1 && top-level): no break label when there is no return, or when the single return sits at the top level of the body */ var breakLabelRef: JsNameRef? = null if (needBreakLabel) {
diff --git a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectUtils.kt b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectUtils.kt index 2f6658d..e0af8c4 100644 --- a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectUtils.kt +++ b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectUtils.kt
@@ -68,8 +68,11 @@ return namedFunctions } -public fun collectInstances<T : JsNode>(klass: Class<T>, scope: JsNode): List<T> { - return with(InstanceCollector(klass)) { +kotlin.jvm.overloads /* also exposes a (klass, scope) overload so Java callers can omit the flag */ +public fun collectInstances<T : JsNode>( + klass: Class<T>, scope: JsNode, visitNestedDeclarations: Boolean = false /* when false, JsFunction and JsObjectLiteral nodes under scope are neither collected nor descended into */ +): List<T> { + return with(InstanceCollector(klass, visitNestedDeclarations)) { accept(scope) collected }
diff --git a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectors/InstanceCollector.kt b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectors/InstanceCollector.kt index d6f882b..a7576a0 100644 --- a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectors/InstanceCollector.kt +++ b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/collectors/InstanceCollector.kt
@@ -16,13 +16,27 @@ package org.jetbrains.kotlin.js.inline.util.collectors +import com.google.dart.compiler.backend.js.ast.JsFunction import com.google.dart.compiler.backend.js.ast.JsNode +import com.google.dart.compiler.backend.js.ast.JsObjectLiteral import com.google.dart.compiler.backend.js.ast.RecursiveJsVisitor import java.util.ArrayList -class InstanceCollector<T : JsNode>(val klass: Class<T>) : RecursiveJsVisitor() { +class InstanceCollector<T : JsNode>(val klass: Class<T>, val visitNestedDeclarations: Boolean) : RecursiveJsVisitor() { public val collected: MutableList<T> = ArrayList() + override fun visitFunction(x: JsFunction) { /* when visitNestedDeclarations is false, x itself is not collected and its children are not traversed from here */ + if (visitNestedDeclarations) { + visitElement(x) + } + } + + override fun visitObjectLiteral(x: JsObjectLiteral) { /* same rule as visitFunction: skipped entirely unless visitNestedDeclarations */ + if (visitNestedDeclarations) { + visitElement(x) + } + } + override fun visitElement(node: JsNode) { if (klass.isInstance(node)) { collected.add(klass.cast(node)!!)
diff --git a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/rewriteUtils.kt b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/rewriteUtils.kt index 1095deb..82bc66c 100644 --- a/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/rewriteUtils.kt +++ b/js/js.inliner/src/org/jetbrains/kotlin/js/inline/util/rewriteUtils.kt
@@ -16,10 +16,7 @@ package org.jetbrains.kotlin.js.inline.util -import com.google.dart.compiler.backend.js.ast.JsExpression -import com.google.dart.compiler.backend.js.ast.JsName -import com.google.dart.compiler.backend.js.ast.JsNameRef -import com.google.dart.compiler.backend.js.ast.JsNode +import com.google.dart.compiler.backend.js.ast.* import org.jetbrains.kotlin.js.inline.util.rewriters.NameReplacingVisitor import org.jetbrains.kotlin.js.inline.util.rewriters.ReturnReplacingVisitor import org.jetbrains.kotlin.js.inline.util.rewriters.ThisReplacingVisitor @@ -30,8 +27,20 @@ return NameReplacingVisitor(replaceMap).accept(node)!! } -public fun replaceReturns(scope: JsNode, resultRef: JsNameRef?, breakLabel: JsNameRef?): JsNode { - return ReturnReplacingVisitor(resultRef, breakLabel).accept(scope)!! +public fun replaceReturns(scope: JsBlock, resultRef: JsNameRef?, breakLabel: JsNameRef?): JsNode { + val visitor = ReturnReplacingVisitor(resultRef, breakLabel) + val withReturnReplaced = visitor.accept(scope)!! + + if (breakLabel != null) { + val statements = scope.getStatements() + val lastBreakLabel = (statements.lastOrNull() as? JsBreak)?.getLabel()?.getName() /* null for an empty block, a last statement that is not a break, or an unlabeled break */ + + if (lastBreakLabel != null && lastBreakLabel === breakLabel.getName()) { + statements.remove(statements.lastIndex) /* assumes breakLabel labels this block, so a trailing break to it is redundant */ + } + } + + return withReturnReplaced } public fun replaceThisReference<T : JsNode>(node: T, replacement: JsExpression) {
diff --git a/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/InlineSizeReductionTestGenerated.java b/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/InlineSizeReductionTestGenerated.java index 0f622f0..a456338 100644 --- a/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/InlineSizeReductionTestGenerated.java +++ b/js/js.tests/test/org/jetbrains/kotlin/js/test/semantics/InlineSizeReductionTestGenerated.java
@@ -35,6 +35,12 @@ JetTestUtils.assertAllTestsPresentByMetadata(this.getClass(), new File("js/js.translator/testData/inlineSizeReduction/cases"), Pattern.compile("^(.+)\\.kt$"), true); } + @TestMetadata("lastBreak.kt") + public void testLastBreak() throws Exception { + String fileName = JetTestUtils.navigationMetadata("js/js.translator/testData/inlineSizeReduction/cases/lastBreak.kt"); + doTest(fileName); + } + @TestMetadata("oneTopLevelReturn.kt") public void testOneTopLevelReturn() throws Exception { String fileName = JetTestUtils.navigationMetadata("js/js.translator/testData/inlineSizeReduction/cases/oneTopLevelReturn.kt");
diff --git a/js/js.tests/test/org/jetbrains/kotlin/js/test/utils/DirectiveTestUtils.java b/js/js.tests/test/org/jetbrains/kotlin/js/test/utils/DirectiveTestUtils.java index 6dd2986..a093a10 100644 --- a/js/js.tests/test/org/jetbrains/kotlin/js/test/utils/DirectiveTestUtils.java +++ b/js/js.tests/test/org/jetbrains/kotlin/js/test/utils/DirectiveTestUtils.java
@@ -104,8 +104,8 @@ String countStr = arguments.getNamedArgument("count"); int expectedCount = Integer.valueOf(countStr); - JsNode scope = AstSearchUtil.getFunction(ast, functionName); - List<T> nodes = collectInstances(klass, scope); + JsFunction function = AstSearchUtil.getFunction(ast, functionName); + List<T> nodes = collectInstances(klass, function.getBody()); /* start from the body: this overload skips JsFunction nodes, so passing the function itself would collect nothing */ int actualCount = 0; for (T node : nodes) { @@ -138,6 +138,8 @@ private static final DirectiveHandler COUNT_VARS = new CountNodesDirective<JsVars.JsVar>("CHECK_VARS_COUNT", JsVars.JsVar.class); + private static final DirectiveHandler COUNT_BREAKS = new CountNodesDirective<JsBreak>("CHECK_BREAKS_COUNT", JsBreak.class); /* counts JsBreak nodes in the named function's body */ + private static final DirectiveHandler HAS_INLINE_METADATA = new DirectiveHandler("CHECK_HAS_INLINE_METADATA") { @Override void processEntry(@NotNull JsNode ast, @NotNull ArgumentsHelper arguments) throws Exception { @@ -166,6 +168,7 @@ FUNCTIONS_HAVE_SAME_LINES, COUNT_LABELS, COUNT_VARS, + COUNT_BREAKS, HAS_INLINE_METADATA, HAS_NO_INLINE_METADATA );
diff --git a/js/js.translator/testData/inlineSizeReduction/cases/lastBreak.kt b/js/js.translator/testData/inlineSizeReduction/cases/lastBreak.kt new file mode 100644 index 0000000..6ed748b --- /dev/null +++ b/js/js.translator/testData/inlineSizeReduction/cases/lastBreak.kt
@@ -0,0 +1,37 @@ +package foo + +// CHECK_NOT_CALLED: f1 +// CHECK_NOT_CALLED: f2 +// CHECK_BREAKS_COUNT: function=test count=3 + +var even = arrayListOf<Int>() +var odd = arrayListOf<Int>() + +inline fun f2(x: Int): Unit { + if (x % 2 == 0) { + even.add(x) + return + } + + odd.add(x) + return +} + +inline fun f1(x: Boolean, y: Int, z: Int): Unit { + if (x) { + return f2(y) + } + + return f2(z) +} + +fun test(x: Boolean, y: Int, z: Int): Unit = f1(x, y, z) /* f1 is inlined here, and f1 in turn inlines f2 at two call sites */ + +fun box(): String { + test(true, 2, 1) + test(false, 2, 1) + assertEquals(listOf(2), even) + assertEquals(listOf(1), odd) + + return "OK" +} \ No newline at end of file