fix: don't remove arrow from lambdas that are when/if leaf nodes (#2758)

Co-authored-by: Paul Dingemans <paul-dingemans@users.noreply.github.com>
diff --git a/ktlint-ruleset-standard/src/main/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRule.kt b/ktlint-ruleset-standard/src/main/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRule.kt
index e183e28..165da85 100644
--- a/ktlint-ruleset-standard/src/main/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRule.kt
+++ b/ktlint-ruleset-standard/src/main/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRule.kt
@@ -4,13 +4,16 @@
 import com.pinterest.ktlint.rule.engine.core.api.AutocorrectDecision
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.ARROW
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.BLOCK
+import com.pinterest.ktlint.rule.engine.core.api.ElementType.ELSE
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.FUNCTION_LITERAL
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.LAMBDA_ARGUMENT
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.LAMBDA_EXPRESSION
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.LBRACE
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.RBRACE
+import com.pinterest.ktlint.rule.engine.core.api.ElementType.THEN
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.VALUE_PARAMETER
 import com.pinterest.ktlint.rule.engine.core.api.ElementType.VALUE_PARAMETER_LIST
+import com.pinterest.ktlint.rule.engine.core.api.ElementType.WHEN_ENTRY
 import com.pinterest.ktlint.rule.engine.core.api.IndentConfig
 import com.pinterest.ktlint.rule.engine.core.api.Rule.VisitorModifier.RunAfterRule.Mode.REGARDLESS_WHETHER_RUN_AFTER_RULE_IS_LOADED_OR_DISABLED
 import com.pinterest.ktlint.rule.engine.core.api.RuleId
@@ -366,7 +369,9 @@
         require(arrow.elementType == ARROW)
         arrow
             .prevSibling { it.elementType == VALUE_PARAMETER_LIST }
-            ?.takeIf { it.findChildByType(VALUE_PARAMETER) == null && arrow.isFollowedByNonEmptyBlock() }
+            ?.takeIf { it.hasEmptyParameterList() }
+            ?.takeUnless { arrow.isLambdaExpressionNotWrappedInBlock() }
+            ?.takeIf { arrow.isFollowedByNonEmptyBlock() }
             ?.let {
                 emit(arrow.startOffset, "Arrow is redundant when parameter list is empty", true)
                     .ifAutocorrectAllowed {
@@ -379,6 +384,29 @@
             }
     }
 
+    private fun ASTNode.hasEmptyParameterList(): Boolean {
+        require(elementType == VALUE_PARAMETER_LIST)
+        return findChildByType(VALUE_PARAMETER) == null
+    }
+
+    private fun ASTNode.isLambdaExpressionNotWrappedInBlock(): Boolean {
+        require(elementType == ARROW)
+        return parent(LAMBDA_EXPRESSION)
+            ?.treeParent
+            ?.elementType
+            ?.let { parentElementType ->
+                // Allow:
+                //     val foo = when {
+                //         1 == 2 -> { -> "hi" }
+                //         else -> { -> "ho" }
+                //     }
+                // or
+                //     val foo = if (cond) { -> "hi" } else { -> "ho" } parent ->
+                parentElementType == WHEN_ENTRY || parentElementType == THEN || parentElementType == ELSE
+            }
+            ?: false
+    }
+
     private fun ASTNode.isFollowedByNonEmptyBlock(): Boolean {
         require(elementType == ARROW)
         return nextSibling { it.elementType == BLOCK }?.firstChildNode != null
diff --git a/ktlint-ruleset-standard/src/test/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRuleTest.kt b/ktlint-ruleset-standard/src/test/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRuleTest.kt
index 26a809d..58c3ff8 100644
--- a/ktlint-ruleset-standard/src/test/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRuleTest.kt
+++ b/ktlint-ruleset-standard/src/test/kotlin/com/pinterest/ktlint/ruleset/standard/rules/FunctionLiteralRuleTest.kt
@@ -506,4 +506,76 @@
             """.trimIndent()
         functionLiteralRuleAssertThat(code).hasNoLintViolations()
     }
+
+    @Test
+    fun `Issue 2758 - Given function literal with an arrow without parameters arrow literal as leaf of when then do not remove the arrow`() {
+        val code =
+            """
+            val foo =
+                when {
+                    false -> { -> "bar" }
+                    else -> { -> "baz" }
+                }
+            """.trimIndent()
+        functionLiteralRuleAssertThat(code).hasNoLintViolations()
+    }
+
+    @Test
+    fun `Issue 2758 - Given function literal with an arrow without parameters arrow literal not as leaf of when then do remove the arrow`() {
+        val code =
+            """
+            val foo =
+                when {
+                    false -> { { -> "bar" } }
+                    else -> { { -> "baz" } }
+                }
+            """.trimIndent()
+        val formattedCode =
+            """
+            val foo =
+                when {
+                    false -> { { "bar" } }
+                    else -> { { "baz" } }
+                }
+            """.trimIndent()
+        functionLiteralRuleAssertThat(code)
+            .hasLintViolations(
+                LintViolation(3, 22, "Arrow is redundant when parameter list is empty"),
+                LintViolation(4, 21, "Arrow is redundant when parameter list is empty"),
+            ).isFormattedAs(formattedCode)
+    }
+
+    @Test
+    fun `Issue 2758 - Given function literal with an arrow without parameters arrow literal as leaf of if then do not remove the arrow`() {
+        val code =
+            """
+            val foo = if (cond) { -> "bar" } else { -> "baz" }
+            """.trimIndent()
+        functionLiteralRuleAssertThat(code).hasNoLintViolations()
+    }
+
+    @Test
+    fun `Issue 2758 - Given function literal with an arrow without parameters arrow literal not as leaf of if then do remove the arrow`() {
+        val code =
+            """
+            val foo = if (cond) {
+                { -> "bar" }
+            } else {
+                { -> "baz" }
+            }
+            """.trimIndent()
+        val formattedCode =
+            """
+            val foo = if (cond) {
+                { "bar" }
+            } else {
+                { "baz" }
+            }
+            """.trimIndent()
+        functionLiteralRuleAssertThat(code)
+            .hasLintViolations(
+                LintViolation(2, 7, "Arrow is redundant when parameter list is empty"),
+                LintViolation(4, 7, "Arrow is redundant when parameter list is empty"),
+            ).isFormattedAs(formattedCode)
+    }
 }