[IR] Rewrite captured type parameters in local classes to explicit type parameters
diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt
index f2beccf..82a4c14 100644
--- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt
+++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt
@@ -60,14 +60,6 @@
 
 object BOUND_RECEIVER_PARAMETER : IrDeclarationOriginImpl("BOUND_RECEIVER_PARAMETER")
 
-/*
-  Local functions raised in LocalDeclarationLowering continue to refer to
-  type parameters no longer visible to them.
-  We add new type parameters to their declarations, which
-  makes JVM accept those declarations. The generated IR is still
-  semantically incorrect (TODO: needs further fix), but code generation seems
-  to proceed nevertheless.
-*/
 class LocalDeclarationsLowering(
     val context: CommonBackendContext,
     val localNameSanitizer: (String) -> String = { it },
@@ -244,9 +236,9 @@
         return typeRemapper.remapType(type)
     }
 
-    private fun LocalContext.remapTypes(body: IrBody) {
+    private fun LocalContext.remapTypes(element: IrElement) {
         if (capturedTypeParameterToTypeParameter.isEmpty()) return
-        body.remapTypes(typeRemapper)
+        element.remapTypes(typeRemapper)
     }
 
     private inner class LocalDeclarationsTransformer(
@@ -473,7 +465,7 @@
                     origin = expression.origin
                 ).also {
                     it.fillArguments2(expression, newCallee)
-                    it.setLocalTypeArguments(oldCallee)
+                    it.setLocalTypeArguments { localFunctions[oldCallee] }
                     it.copyTypeArgumentsFrom(expression, shift = typeParameters.size - expression.typeArgumentsCount)
                     it.copyAttributes(expression)
                 }
@@ -580,43 +572,46 @@
                 origin = oldCall.origin,
                 superQualifierSymbol = oldCall.superQualifierSymbol
             ).also {
-                it.setLocalTypeArguments(oldCall.symbol.owner)
+                it.setLocalTypeArguments { localFunctions[oldCall.symbol.owner] }
                 it.copyTypeArgumentsFrom(oldCall, shift = newCallee.typeParameters.size - oldCall.typeArgumentsCount)
             }
 
-        private fun createNewCall(oldCall: IrConstructorCall, newCallee: IrConstructor) =
-            IrConstructorCallImpl.fromSymbolOwner(
+        private fun createNewCall(oldCall: IrConstructorCall, newCallee: IrConstructor): IrConstructorCallImpl {
+            val localClass = newCallee.parentAsClass
+            return IrConstructorCallImpl.fromSymbolOwner(
                 oldCall.startOffset, oldCall.endOffset,
                 oldCall.type,
                 newCallee.symbol,
-                newCallee.parentAsClass.typeParameters.size,
+                localClass.typeParameters.size,
                 oldCall.origin
             ).also {
-                it.copyTypeArgumentsFrom(oldCall)
+                it.setLocalTypeArguments { localClasses[localClass] }
+                it.copyTypeArgumentsFrom(oldCall, shift = localClass.typeParameters.size - oldCall.typeArgumentsCount)
             }
+        }
 
-        private fun IrMemberAccessExpression<*>.setLocalTypeArguments(callee: IrFunction) {
-            val context = localFunctions[callee] ?: return
+        private inline fun IrMemberAccessExpression<*>.setLocalTypeArguments(getContext: () -> LocalContext?) {
+            val context = getContext() ?: return
             for ((outerTypeParameter, innerTypeParameter) in context.capturedTypeParameterToTypeParameter) {
                 putTypeArgument(innerTypeParameter.index, outerTypeParameter.defaultType) // TODO: remap default type!
             }
         }
 
         private fun transformDeclarations() {
-            localFunctions.values.forEach {
-                createLiftedDeclaration(it)
-            }
+            localFunctions.values.forEach(::createLiftedFunctionDeclaration)
 
-            localClasses.values.forEach {
-                it.declaration.visibility = visibilityPolicy.forClass(it.declaration, it.inInlineFunctionScope)
-                it.closure.capturedValues.associateTo(it.capturedValueToField) { capturedValue ->
+            localClasses.values.forEach { localClassContext ->
+                val declaration = localClassContext.declaration
+                val closure = localClassContext.closure
+                declaration.visibility = visibilityPolicy.forClass(declaration, localClassContext.inInlineFunctionScope)
+                closure.capturedValues.associateTo(localClassContext.capturedValueToField) { capturedValue ->
                     capturedValue.owner to PotentiallyUnusedField()
                 }
+                handleCapturedTypeParameters(declaration, declaration, closure, localClassContext)
+                localClassContext.remapTypes(declaration)
             }
 
-            localClassConstructors.values.forEach {
-                createTransformedConstructorDeclaration(it)
-            }
+            localClassConstructors.values.forEach(::createTransformedConstructorDeclaration)
         }
 
         private fun suggestLocalName(declaration: IrDeclarationWithName): String {
@@ -654,7 +649,57 @@
                 Name.identifier(nameFromParents)
         }
 
-        private fun createLiftedDeclaration(localFunctionContext: LocalFunctionContext) {
+        /**
+         * Adds the type parameters captured by [oldDeclaration] to [newDeclaration].
+         *
+         * In other words, transforms
+         *
+         * ```kotlin
+         * fun <T> foo(t: T) {
+         *     class A<S>(val p0: T, val p1: S)
+         *
+         *     fun <S> bar(p0: T, p1: S) {}
+         * }
+         * ```
+         *
+         * into
+         *
+         * ```kotlin
+         * fun <T> foo(t: T) {
+         *     class A<P, S>(val p0: P, val p1: S)
+         *
+         *     fun <P, S> bar(p0: P, p1: S) {}
+         * }
+         * ```
+         *
+         * [newDeclaration] and [oldDeclaration] may be the same declaration.
+         */
+        private fun handleCapturedTypeParameters(
+            newDeclaration: IrTypeParametersContainer,
+            oldDeclaration: IrTypeParametersContainer,
+            closure: Closure,
+            localContext: LocalContext
+        ) {
+            // Make sure that if oldDeclaration === newDeclaration, we don't end up having duplicate TPs.
+            val existingTypeParameters = oldDeclaration.typeParameters
+            newDeclaration.typeParameters = emptyList()
+
+            val capturedTypeParameters = closure.capturedTypeParameters
+            val newTypeParameters = newDeclaration.copyTypeParameters(capturedTypeParameters)
+            localContext.capturedTypeParameterToTypeParameter.putAll(
+                capturedTypeParameters.zip(newTypeParameters)
+            )
+            newDeclaration.copyTypeParameters(existingTypeParameters, parameterMap = localContext.capturedTypeParameterToTypeParameter)
+            localContext.capturedTypeParameterToTypeParameter.putAll(
+                existingTypeParameters.zip(newDeclaration.typeParameters.drop(newTypeParameters.size))
+            )
+            // Type parameters of oldDeclaration may depend on captured type parameters, so deal with that after copying.
+            newDeclaration.typeParameters.drop(newTypeParameters.size).forEach { tp ->
+                tp.superTypes = tp.superTypes.map { localContext.remapType(it) }
+            }
+        }
+
+        private fun createLiftedFunctionDeclaration(localFunctionContext: LocalFunctionContext) {
             val oldDeclaration = localFunctionContext.declaration
             if (oldDeclaration.dispatchReceiverParameter != null) {
                 throw AssertionError("local functions must not have dispatch receiver")
@@ -665,7 +710,7 @@
             val newName = generateNameForLiftedDeclaration(oldDeclaration, ownerParent)
 
             // TODO: consider using fields to access the closure of enclosing class.
-            val (capturedValues, capturedTypeParameters) = localFunctionContext.closure
+            val capturedValues = localFunctionContext.closure.capturedValues
 
             val newDeclaration = context.irFactory.buildFun {
                 updateFrom(oldDeclaration)
@@ -676,18 +721,7 @@
 
             localFunctionContext.transformedDeclaration = newDeclaration
 
-            val newTypeParameters = newDeclaration.copyTypeParameters(capturedTypeParameters)
-            localFunctionContext.capturedTypeParameterToTypeParameter.putAll(
-                capturedTypeParameters.zip(newTypeParameters)
-            )
-            newDeclaration.copyTypeParametersFrom(oldDeclaration, parameterMap = localFunctionContext.capturedTypeParameterToTypeParameter)
-            localFunctionContext.capturedTypeParameterToTypeParameter.putAll(
-                oldDeclaration.typeParameters.zip(newDeclaration.typeParameters.drop(newTypeParameters.size))
-            )
-            // Type parameters of oldDeclaration may depend on captured type parameters, so deal with that after copying.
-            newDeclaration.typeParameters.drop(newTypeParameters.size).forEach { tp ->
-                tp.superTypes = tp.superTypes.map { localFunctionContext.remapType(it) }
-            }
+            handleCapturedTypeParameters(newDeclaration, oldDeclaration, localFunctionContext.closure, localFunctionContext)
 
             newDeclaration.parent = ownerParent
             newDeclaration.returnType = localFunctionContext.remapType(oldDeclaration.returnType)
@@ -840,7 +874,11 @@
             }
         }
 
-        private fun suggestNameForCapturedValue(declaration: IrValueDeclaration, usedNames: MutableSet<String>, isExplicitLocalFunction: Boolean = false): Name {
+        private fun suggestNameForCapturedValue(
+            declaration: IrValueDeclaration,
+            usedNames: MutableSet<String>,
+            isExplicitLocalFunction: Boolean = false
+        ): Name {
             if (declaration is IrValueParameter) {
                 if (declaration.name.asString() == "<this>" && declaration.isDispatchReceiver()) {
                     return findFirstUnusedName("this\$0", usedNames) {
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt
index 150514d..4272530 100644
--- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt
+++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt
@@ -652,13 +652,13 @@
             }
         }
 
-        factoryDeclaration.valueParameters = constructor.valueParameters.map { it.copyTo(factoryDeclaration) }
-        factoryDeclaration.typeParameters = constructor.typeParameters.map {
+        factoryDeclaration.typeParameters = lambdaInfo.lambdaClass.typeParameters.map {
             it.copyToWithoutSuperTypes(factoryDeclaration).also { tp ->
                 // TODO: make sure it is done well
                 tp.superTypes += it.superTypes
             }
         }
+        factoryDeclaration.valueParameters = constructor.valueParameters.map { it.copyTo(factoryDeclaration) }
 
         factoryDeclaration.body = buildFactoryBody(factoryDeclaration, newDeclarations, lambdaInfo)
 
diff --git a/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt b/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt
index cfd6985..654c52b 100644
--- a/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt
+++ b/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt
@@ -14,7 +14,36 @@
     override fun toInt() = v
 }
 
+interface Grouping<GroupingInputTP, out GroupingOutputTP> {
+    fun keyOf(element: GroupingInputTP): GroupingOutputTP
+}
+
+fun <GroupingByTP> groupingBy(keySelector: (Char) -> GroupingByTP): Grouping<Char, GroupingByTP> {
+
+    fun <T> foo(p0: T, p1: GroupingByTP) {}
+
+    foo(0, keySelector('a'))
+
+    class A<T>(p0: T, p1: GroupingByTP) {}
+
+    A(0, keySelector('a'))
+
+    return object : Grouping<Char, GroupingByTP> {
+        override fun keyOf(element: Char): GroupingByTP = keySelector(element)
+    }
+}
+
+class Delft<DelftTP> {
+    fun getComparator(other: DelftTP) = { this == other }
+}
+
 fun box(): String {
-    if (computeSum(arrayOf(N(2), N(14))) != 16) return "Fail"
+    if (computeSum(arrayOf(N(2), N(14))) != 16) return "Fail1"
+
+    if (groupingBy { it }.keyOf('A') != 'A') return "Fail2"
+
+    val delft = Delft<Any>()
+    if (!delft.getComparator(delft)()) return "Fail3"
+    if (Delft<Any>().getComparator(delft)()) return "Fail4"
     return "OK"
 }