[IR] Rewrite captured type parameters in local classes to explicit TPs
diff --git a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt
index f2beccf..82a4c14 100644
--- a/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt
+++ b/compiler/ir/backend.common/src/org/jetbrains/kotlin/backend/common/lower/LocalDeclarationsLowering.kt
@@ -60,14 +60,6 @@
object BOUND_RECEIVER_PARAMETER : IrDeclarationOriginImpl("BOUND_RECEIVER_PARAMETER")
-/*
- Local functions raised in LocalDeclarationLowering continue to refer to
- type parameters no longer visible to them.
- We add new type parameters to their declarations, which
- makes JVM accept those declarations. The generated IR is still
- semantically incorrect (TODO: needs further fix), but code generation seems
- to proceed nevertheless.
-*/
class LocalDeclarationsLowering(
val context: CommonBackendContext,
val localNameSanitizer: (String) -> String = { it },
@@ -244,9 +236,9 @@
return typeRemapper.remapType(type)
}
- private fun LocalContext.remapTypes(body: IrBody) {
+ /**
+ * Applies this context's type remapper to the types in [element], which may be a function body or a whole
+ * declaration (e.g. a local class that was rewritten in place).
+ * Does nothing when `capturedTypeParameterToTypeParameter` is empty.
+ */
+ private fun LocalContext.remapTypes(element: IrElement) {
if (capturedTypeParameterToTypeParameter.isEmpty()) return
- body.remapTypes(typeRemapper)
+ element.remapTypes(typeRemapper)
}
private inner class LocalDeclarationsTransformer(
@@ -473,7 +465,7 @@
origin = expression.origin
).also {
it.fillArguments2(expression, newCallee)
- it.setLocalTypeArguments(oldCallee)
+ it.setLocalTypeArguments { localFunctions[oldCallee] }
it.copyTypeArgumentsFrom(expression, shift = typeParameters.size - expression.typeArgumentsCount)
it.copyAttributes(expression)
}
@@ -580,43 +572,46 @@
origin = oldCall.origin,
superQualifierSymbol = oldCall.superQualifierSymbol
).also {
- it.setLocalTypeArguments(oldCall.symbol.owner)
+ it.setLocalTypeArguments { localFunctions[oldCall.symbol.owner] }
it.copyTypeArgumentsFrom(oldCall, shift = newCallee.typeParameters.size - oldCall.typeArgumentsCount)
}
- private fun createNewCall(oldCall: IrConstructorCall, newCallee: IrConstructor) =
- IrConstructorCallImpl.fromSymbolOwner(
+ private fun createNewCall(oldCall: IrConstructorCall, newCallee: IrConstructor): IrConstructorCallImpl {
+ val localClass = newCallee.parentAsClass
+ // NOTE(review): oldCall.type is reused unchanged, although localClass may have gained captured type
+ // parameters in transformDeclarations; confirm the arguments of that type are adjusted elsewhere.
+ return IrConstructorCallImpl.fromSymbolOwner(
oldCall.startOffset, oldCall.endOffset,
oldCall.type,
newCallee.symbol,
- newCallee.parentAsClass.typeParameters.size,
+ localClass.typeParameters.size,
oldCall.origin
).also {
- it.copyTypeArgumentsFrom(oldCall)
+ // Captured type parameters come first in localClass, so the original type arguments are shifted past them.
+ it.setLocalTypeArguments { localClasses[localClass] }
+ it.copyTypeArgumentsFrom(oldCall, shift = localClass.typeParameters.size - oldCall.typeArgumentsCount)
}
+ }
- private fun IrMemberAccessExpression<*>.setLocalTypeArguments(callee: IrFunction) {
- val context = localFunctions[callee] ?: return
+ /**
+ * For every entry of the [LocalContext] returned by [getContext], passes the outer type parameter as the type
+ * argument for the corresponding inner type parameter. Does nothing if [getContext] returns `null`.
+ *
+ * The mapping also contains the local declaration's own type parameters (mapped to their copies), so placeholder
+ * arguments are written at those positions too. Callers follow up with `copyTypeArgumentsFrom(..., shift = ...)`,
+ * which is expected to overwrite those positions with the real arguments, so that call must stay after this one.
+ */
+ private inline fun IrMemberAccessExpression<*>.setLocalTypeArguments(getContext: () -> LocalContext?) {
+ val context = getContext() ?: return
for ((outerTypeParameter, innerTypeParameter) in context.capturedTypeParameterToTypeParameter) {
putTypeArgument(innerTypeParameter.index, outerTypeParameter.defaultType) // TODO: remap default type!
}
}
private fun transformDeclarations() {
- localFunctions.values.forEach {
- createLiftedDeclaration(it)
- }
+ localFunctions.values.forEach(::createLiftedFunctionDeclaration)
- localClasses.values.forEach {
- it.declaration.visibility = visibilityPolicy.forClass(it.declaration, it.inInlineFunctionScope)
- it.closure.capturedValues.associateTo(it.capturedValueToField) { capturedValue ->
+ localClasses.values.forEach { localClassContext ->
+ val declaration = localClassContext.declaration
+ val closure = localClassContext.closure
+ declaration.visibility = visibilityPolicy.forClass(declaration, localClassContext.inInlineFunctionScope)
+ closure.capturedValues.associateTo(localClassContext.capturedValueToField) { capturedValue ->
capturedValue.owner to PotentiallyUnusedField()
}
+ // Local classes are rewritten in place: captured type parameters become the class's leading type
+ // parameters, and references to the old type parameters inside the class are then remapped.
+ handleCapturedTypeParameters(declaration, declaration, closure, localClassContext)
+ localClassContext.remapTypes(declaration)
}
- localClassConstructors.values.forEach {
- createTransformedConstructorDeclaration(it)
- }
+ localClassConstructors.values.forEach(::createTransformedConstructorDeclaration)
}
private fun suggestLocalName(declaration: IrDeclarationWithName): String {
@@ -654,7 +649,57 @@
Name.identifier(nameFromParents)
}
- private fun createLiftedDeclaration(localFunctionContext: LocalFunctionContext) {
+ /**
+ * Gives [newDeclaration] explicit type parameters for every type parameter [oldDeclaration] captures from its
+ * enclosing scope, followed by copies of [oldDeclaration]'s own type parameters. Both the captured and the own
+ * type parameters are recorded in `localContext.capturedTypeParameterToTypeParameter`, mapped to their copies.
+ *
+ * In other words, transforms
+ *
+ * ```kotlin
+ * fun <T> foo(t: T) {
+ *     class A<S>(val p0: T, val p1: S)
+ *
+ *     fun <S> bar(p0: T, p1: S) {}
+ * }
+ * ```
+ *
+ * into (moving `bar` out of `foo` is done separately by the caller)
+ *
+ * ```kotlin
+ * fun <T> foo(t: T) {
+ *     class A<P, S>(val p0: P, val p1: S)
+ * }
+ *
+ * fun <P, S> bar(p0: P, p1: S) {}
+ * ```
+ *
+ * [newDeclaration] and [oldDeclaration] may be the same declaration; local classes are rewritten in place this way.
+ * Only the supertypes of the copied own type parameters are remapped here. Other types referring to the old type
+ * parameters still have to be remapped by the caller (via `remapType`/`remapTypes`).
+ */
+ private fun handleCapturedTypeParameters(
+ newDeclaration: IrTypeParametersContainer,
+ oldDeclaration: IrTypeParametersContainer,
+ closure: Closure,
+ localContext: LocalContext
+ ) {
+ // Snapshot the old type parameters before clearing, so that if oldDeclaration === newDeclaration
+ // we don't end up having duplicate TPs.
+ val existingTypeParameters = oldDeclaration.typeParameters
+ newDeclaration.typeParameters = emptyList()
+
+ // Captured type parameters go first...
+ val capturedTypeParameters = closure.capturedTypeParameters
+ val newTypeParameters = newDeclaration.copyTypeParameters(capturedTypeParameters)
+ localContext.capturedTypeParameterToTypeParameter.putAll(
+ capturedTypeParameters.zip(newTypeParameters)
+ )
+ // ...followed by copies of oldDeclaration's own type parameters.
+ newDeclaration.copyTypeParameters(existingTypeParameters, parameterMap = localContext.capturedTypeParameterToTypeParameter)
+ val copiedOwnTypeParameters = newDeclaration.typeParameters.drop(newTypeParameters.size)
+ localContext.capturedTypeParameterToTypeParameter.putAll(
+ existingTypeParameters.zip(copiedOwnTypeParameters)
+ )
+ // Type parameters of oldDeclaration may depend on captured type parameters, so deal with that after copying.
+ copiedOwnTypeParameters.forEach { tp ->
+ tp.superTypes = tp.superTypes.map { localContext.remapType(it) }
+ }
+ }
+
+ private fun createLiftedFunctionDeclaration(localFunctionContext: LocalFunctionContext) {
val oldDeclaration = localFunctionContext.declaration
if (oldDeclaration.dispatchReceiverParameter != null) {
throw AssertionError("local functions must not have dispatch receiver")
@@ -665,7 +710,7 @@
val newName = generateNameForLiftedDeclaration(oldDeclaration, ownerParent)
// TODO: consider using fields to access the closure of enclosing class.
- val (capturedValues, capturedTypeParameters) = localFunctionContext.closure
+ val capturedValues = localFunctionContext.closure.capturedValues
val newDeclaration = context.irFactory.buildFun {
updateFrom(oldDeclaration)
@@ -676,18 +721,7 @@
localFunctionContext.transformedDeclaration = newDeclaration
- val newTypeParameters = newDeclaration.copyTypeParameters(capturedTypeParameters)
- localFunctionContext.capturedTypeParameterToTypeParameter.putAll(
- capturedTypeParameters.zip(newTypeParameters)
- )
- newDeclaration.copyTypeParametersFrom(oldDeclaration, parameterMap = localFunctionContext.capturedTypeParameterToTypeParameter)
- localFunctionContext.capturedTypeParameterToTypeParameter.putAll(
- oldDeclaration.typeParameters.zip(newDeclaration.typeParameters.drop(newTypeParameters.size))
- )
- // Type parameters of oldDeclaration may depend on captured type parameters, so deal with that after copying.
- newDeclaration.typeParameters.drop(newTypeParameters.size).forEach { tp ->
- tp.superTypes = tp.superTypes.map { localFunctionContext.remapType(it) }
- }
+ // The lifted function gets the captured type parameters as its leading type parameters,
+ // followed by copies of oldDeclaration's own type parameters.
+ handleCapturedTypeParameters(newDeclaration, oldDeclaration, localFunctionContext.closure, localFunctionContext)
newDeclaration.parent = ownerParent
newDeclaration.returnType = localFunctionContext.remapType(oldDeclaration.returnType)
@@ -840,7 +874,11 @@
}
}
- private fun suggestNameForCapturedValue(declaration: IrValueDeclaration, usedNames: MutableSet<String>, isExplicitLocalFunction: Boolean = false): Name {
+ private fun suggestNameForCapturedValue(
+ declaration: IrValueDeclaration,
+ usedNames: MutableSet<String>,
+ isExplicitLocalFunction: Boolean = false
+ ): Name {
if (declaration is IrValueParameter) {
if (declaration.name.asString() == "<this>" && declaration.isDispatchReceiver()) {
return findFirstUnusedName("this\$0", usedNames) {
diff --git a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt
index 150514d..4272530 100644
--- a/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt
+++ b/compiler/ir/backend.js/src/org/jetbrains/kotlin/ir/backend/js/lower/InteropCallableReferenceLowering.kt
@@ -652,13 +652,13 @@
}
}
- factoryDeclaration.valueParameters = constructor.valueParameters.map { it.copyTo(factoryDeclaration) }
- factoryDeclaration.typeParameters = constructor.typeParameters.map {
+ // Type parameters are copied from the lambda class and set up before the value parameters below.
+ // NOTE(review): presumably `copyTo` remaps the value parameter types from the lambda class's type parameters
+ // onto factoryDeclaration's, which requires them to exist already — confirm before reordering these statements.
+ factoryDeclaration.typeParameters = lambdaInfo.lambdaClass.typeParameters.map {
it.copyToWithoutSuperTypes(factoryDeclaration).also { tp ->
// TODO: make sure it is done well
tp.superTypes += it.superTypes
}
}
+ factoryDeclaration.valueParameters = constructor.valueParameters.map { it.copyTo(factoryDeclaration) }
factoryDeclaration.body = buildFactoryBody(factoryDeclaration, newDeclarations, lambdaInfo)
diff --git a/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt b/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt
index cfd6985..654c52b 100644
--- a/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt
+++ b/compiler/testData/codegen/box/closures/closureCapturingGenericParam.kt
@@ -14,7 +14,32 @@
override fun toInt() = v
}
+interface Grouping<GroupingInputTP, out GroupingOutputTP> {
+ fun keyOf(element: GroupingInputTP): GroupingOutputTP
+}
+
+fun <GroupingByTP> groupingBy(keySelector: (Char) -> GroupingByTP): Grouping<Char, GroupingByTP> {
+
+ fun <T> foo(p0: T, p1: GroupingByTP) {}
+
+ foo(0, keySelector('a'))
+
+ class A<T>(p0: T, p1: GroupingByTP) {}
+
+ A(0, keySelector('a'))
+
+ return object : Grouping<Char, GroupingByTP> {
+ override fun keyOf(element: Char): GroupingByTP = keySelector(element)
+ }
+}
+
+class Delft<DelftTP> {
+ fun getComparator(other: DelftTP) = { this == other }
+}
+
fun box(): String {
- if (computeSum(arrayOf(N(2), N(14))) != 16) return "Fail"
+ if (computeSum(arrayOf(N(2), N(14))) != 16) return "Fail1"
+
+ if (groupingBy { it }.keyOf('A') != 'A') return "Fail2"
+
+ // Run the lambda that captures the class type parameter DelftTP; a Delft instance never equals a String.
+ if (Delft<String>().getComparator("x")()) return "Fail3"
return "OK"
}