[Compose] Enable applier checking when using FIR

Adds ComposableTarget (applier) checking when using the FIR front-end.

Fixes: [282135108](https://issuetracker.google.com/282135108)
Fixes: [349866442](https://issuetracker.google.com/349866442)
diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt index 257cd4b..e328a88 100644 --- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt +++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt
@@ -313,4 +313,6 @@ else -> "declaration" } } + + /** Renders an arbitrary diagnostic argument with its `toString()`. */ val TO_STRING = Renderer { value: Any -> value.toString() } }
diff --git a/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt b/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt index a262707..ada621c 100644 --- a/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt +++ b/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt
@@ -22,8 +22,7 @@ import org.junit.runner.RunWith import org.junit.runners.JUnit4 -@RunWith(JUnit4::class) -class ComposableTargetCheckerTests : AbstractComposeDiagnosticsTest(useFir = false) { +/* `useFir` selects the front-end under test (FIR/K2 when true, PSI/K1 otherwise); presumably both values are supplied by a parameterized runner in the base class - verify. NOTE(review): with @RunWith(JUnit4::class) gone, the RunWith/JUnit4 imports above look unused. */ class ComposableTargetCheckerTests(useFir: Boolean) : AbstractComposeDiagnosticsTest(useFir) { @Test fun testExplicitTargetAnnotations() = check( """ @@ -174,7 +173,7 @@ @Composable @ComposableTarget("N") fun T() { - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } @Composable @@ -191,7 +190,7 @@ @Composable fun T() { N() - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } @Composable @@ -225,8 +224,8 @@ W { N() } - <!COMPOSE_APPLIER_PARAMETER_MISMATCH!>W<!> { - M() + ${psiParStart()}W${psiEnd()} { + ${firMisStart()}M()${firEnd()} } } """ @@ -254,14 +253,14 @@ } @Composable - fun OpenCustom(content: CustomComposable) { - content.call() + fun OpenCustom(oContent: CustomComposable) { + oContent.call() } @Composable - fun ClosedCustom(content: CustomComposable) { + fun ClosedCustom(cContent: CustomComposable) { N() - content.call() + cContent.call() } @Composable @@ -285,16 +284,16 @@ OpenCustom { N() } - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } @Composable fun ClosedDisagree() { ClosedCustom { N() - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } """ ) @@ -312,7 +311,7 @@ @Composable fun AssumesN() { - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } """ ) @@ -340,7 +339,7 @@ @Composable @NComposable fun AssumesN() { - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } """ ) @@ -369,7 +368,7 @@ @Composable fun AssumesN() { - 
<!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } """ ) @@ -387,7 +386,7 @@ @Composable fun UseText() { BasicText("Some text") - <!COMPOSE_APPLIER_CALL_MISMATCH!>Invalid<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>Invalid${psiEnd()}()${firEnd()} } """, additionalPaths = listOf( @@ -413,7 +412,7 @@ class Invalid : Base() { @Composable override fun Compose() { - <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>() + <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()} } } @@ -439,7 +438,7 @@ } class Invalid : Base() { - <!COMPOSE_APPLIER_DECLARATION_MISMATCH!>@Composable @ComposableTarget("M") override fun Compose() { }<!> + ${psiDecStart()}@Composable @ComposableTarget("M") override fun ${firDecStart()}Compose${firEnd()}() { }${psiEnd()} } class Valid : Base () { @@ -469,4 +468,11 @@ } } """) + + /* Diagnostic-marker helpers: FIR and PSI report some mismatches on different ranges (e.g. FIR marks the whole call `M()`, PSI only the callee `M`), so each marker is emitted only when the selected front-end reports it. */ private fun firEnd() = if (useFir) "<!>" else "" + private fun psiEnd() = if (!useFir) "<!>" else "" + private fun firMisStart() = if (useFir) "<!COMPOSE_APPLIER_CALL_MISMATCH!>" else "" + private fun psiParStart() = if (!useFir) "<!COMPOSE_APPLIER_PARAMETER_MISMATCH!>" else "" + private fun firDecStart() = if (useFir) "<!COMPOSE_APPLIER_DECLARATION_MISMATCH!>" else "" + private fun psiDecStart() = if (!useFir) "<!COMPOSE_APPLIER_DECLARATION_MISMATCH!>" else "" }
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt index 337948c..c6d4885 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt
@@ -44,6 +44,7 @@ val ComposableLambda = internalClassIdFor("ComposableLambda") val ComposableOpenTarget = classIdFor("ComposableOpenTarget") val ComposableTarget = classIdFor("ComposableTarget") + /* ClassId of `@ComposableTargetMarker`; the FIR target checker looks it up on annotation classes. */ val ComposableTargetMarker = classIdFor("ComposableTargetMarker") val ComposeVersion = classIdFor("ComposeVersion") val Composer = classIdFor("Composer") val DisallowComposableCalls = classIdFor("DisallowComposableCalls") @@ -102,6 +103,7 @@ val ComposableTarget = ComposeClassIds.ComposableTarget.asSingleFqName() val ComposableTargetMarker = fqNameFor("ComposableTargetMarker") val ComposableTargetMarkerDescription = "description" + /* [Name] form of [ComposableTargetMarkerDescription], used to look up FIR annotation arguments. */ val ComposableTargetMarkerDescriptionName = Name.identifier(ComposableTargetMarkerDescription) val ComposableTargetApplierArgument = Name.identifier("applier") val ComposableOpenTarget = ComposeClassIds.ComposableOpenTarget.asSingleFqName() val ComposableOpenTargetIndexArgument = Name.identifier("index")
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt index 2908760..09b9056 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt
@@ -19,7 +19,9 @@
 import com.intellij.util.keyFMap.KeyFMap
 import java.util.WeakHashMap
 import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
+import org.jetbrains.kotlin.fir.FirElement
+import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
 import org.jetbrains.kotlin.ir.declarations.IrAttributeContainer
 import org.jetbrains.kotlin.util.slicedMap.ReadOnlySlice
 import org.jetbrains.kotlin.util.slicedMap.WritableSlice
 import java.util.Collections
@@ -46,9 +48,36 @@
     operator fun <K : IrAttributeContainer, V> get(slice: ReadOnlySlice<K, V>, key: K): V? {
         return map[key.attributeOwnerId]?.get(slice.key)
     }
+
+    // FIR overloads: entries are keyed by the FirElement / CheckerContext instance itself.
+    fun <K : FirElement, V> record(slice: WritableSlice<K, V>, key: K, value: V) {
+        recordAny(slice, key, value)
+    }
+
+    operator fun <K : FirElement, V> get(slice: ReadOnlySlice<K, V>, key: K): V? = getAny(slice, key)
+
+    fun <K : CheckerContext, V> record(slice: WritableSlice<K, V>, key: K, value: V) {
+        recordAny(slice, key, value)
+    }
+
+    operator fun <K : CheckerContext, V> get(slice: ReadOnlySlice<K, V>, key: K): V? = getAny(slice, key)
+
+    private fun <K : Any, V> getAny(slice: ReadOnlySlice<K, V>, key: K): V? = map[key]?.get(slice.key)
+
+    private fun <K : Any, V> recordAny(slice: WritableSlice<K, V>, key: K, value: V) {
+        // KeyFMap needs non-null values, so recording null clears the entry instead of throwing an NPE.
+        val holder = (map[key] ?: KeyFMap.EMPTY_MAP).minus(slice.key)
+        map[key] = if (value == null) holder else holder.plus(slice.key, value)
+    }
 }
 
 private val ComposeTemporaryGlobalBindingTrace = WeakBindingTrace()
 
 @Suppress("unused")
 val IrPluginContext.irTrace: WeakBindingTrace get() = ComposeTemporaryGlobalBindingTrace
+
+@Suppress("unused")
+val CheckerContext.firTrace: WeakBindingTrace get() = ComposeTemporaryGlobalBindingTrace
+
+@Suppress("unused")
+val FirElement.firTrace: WeakBindingTrace get() = ComposeTemporaryGlobalBindingTrace
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt index 4253afb..0f45e97 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt
@@ -164,7 +164,7 @@ * back-end IR nodes as well as allows for easier testing and debugging of the itself algorithm * without requiring either tree. */ -class ApplierInferencer<Type, Node>( +open class ApplierInferencer<Type, Node>( /* open so the FIR ComposableTargetChecker can subclass it as FirApplierInference */ private val typeAdapter: TypeAdapter<Type>, private val nodeAdapter: NodeAdapter<Type, Node>, private val lazySchemeStorage: LazySchemeStorage<Node>,
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt index f74024d..05d8d85 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt
@@ -62,7 +62,7 @@ ) { val calleeFunction = expression.calleeReference.toResolvedCallableSymbol() ?: return - if (calleeFunction.isComposable(context.session)) { + if (calleeFunction.isComposable(context)) { /* isComposable is given the full CheckerContext rather than only its session */ checkComposableCall(expression, calleeFunction, context, reporter) } } @@ -81,7 +81,7 @@ // https://youtrack.jetbrains.com/issue/KT-47708. if (calleeFunction.origin == FirDeclarationOrigin.SamConstructor) return - if (calleeFunction.isComposable(context.session)) { + if (calleeFunction.isComposable(context)) { checkComposableCall(expression, calleeFunction, context, reporter) } else if (calleeFunction.callableId.isInvoke()) { checkInvoke(expression, context, reporter)
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt index 3c19708..b845b45 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt
@@ -41,16 +41,20 @@ // Check overrides for mismatched composable annotations for (override in declaration.getDirectOverriddenFunctions(context)) { - if (override.isComposable(context.session) != isComposable) { + if (override.isComposable(context) != isComposable) { reporter.reportOn( declaration.source, FirErrors.CONFLICTING_OVERLOADS, listOf(declaration.symbol, override), context ) + } else if (override.isComposable(context) && !override.toScheme(context).canOverride(declaration.symbol.toScheme(context))) { /* Both are composable: the overridden function's applier scheme must accept this declaration's scheme. */ + reporter.reportOn( + source = declaration.source, + factory = ComposeErrors.COMPOSE_APPLIER_DECLARATION_MISMATCH, + context = context + ) } - - // TODO(b/282135108): Check scheme of override against declaration } // Check that `actual` composable declarations have composable expects
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableTargetChecker.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableTargetChecker.kt new file mode 100644 index 0000000..782ca7e --- /dev/null +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableTargetChecker.kt
@@ -0,0 +1,418 @@
+/*
+ * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package androidx.compose.compiler.plugins.kotlin.k2
+
+import androidx.compose.compiler.plugins.kotlin.ComposeClassIds
+import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
+import androidx.compose.compiler.plugins.kotlin.firTrace
+import androidx.compose.compiler.plugins.kotlin.inference.*
+import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
+import org.jetbrains.kotlin.diagnostics.reportOn
+import org.jetbrains.kotlin.fir.FirAnnotationContainer
+import org.jetbrains.kotlin.fir.FirElement
+import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
+import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
+import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
+import org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction
+import org.jetbrains.kotlin.fir.declarations.FirFunction
+import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId
+import org.jetbrains.kotlin.fir.expressions.*
+import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
+import org.jetbrains.kotlin.fir.resolve.toClassSymbol
+import org.jetbrains.kotlin.fir.scopes.impl.overrides
+import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
+import org.jetbrains.kotlin.fir.symbols.SymbolInternals
+import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
+import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
+import org.jetbrains.kotlin.fir.symbols.impl.FirValueParameterSymbol
+import org.jetbrains.kotlin.fir.types.ConeKotlinType
+import org.jetbrains.kotlin.fir.types.isPrimitive
+import org.jetbrains.kotlin.fir.types.isString
+import org.jetbrains.kotlin.fir.types.resolvedType
+import org.jetbrains.kotlin.name.ClassId
+import org.jetbrains.kotlin.name.Name
+import org.jetbrains.kotlin.util.slicedMap.BasicWritableSlice
+import org.jetbrains.kotlin.util.slicedMap.RewritePolicy
+import org.jetbrains.kotlin.util.slicedMap.WritableSlice
+
+/**
+ * A node in the applier-inference graph built from FIR. A node's identity is its [element]: nodes wrapping
+ * the same FIR element are equal, so they share the lazy scheme stored for that node.
+ */
+private sealed class FirInferenceNode(val element: FirElement) {
+    open val kind: NodeKind get() = NodeKind.Expression
+    abstract val type: InferenceNodeType?
+    open val referenceContainer: FirInferenceNode? get() = null
+    open val parameterIndex: Int get() = -1
+    override fun hashCode(): Int = 31 * element.hashCode()
+    override fun equals(other: Any?): Boolean = other is FirInferenceNode && other.element == element
+}
+
+/** A plain expression node; it carries no callable type of its own. */
+private open class FirElementInferenceNode(element: FirElement) : FirInferenceNode(element) {
+    override val type: InferenceNodeType? get() = null
+}
+
+/** An [element] whose scheme is the scheme of the resolved [callable] it refers to. */
+private class FirCallableElementInferenceNode(
+    val callable: FirCallableSymbol<*>,
+    element: FirElement,
+) : FirElementInferenceNode(element) {
+    override val type: InferenceNodeType = InferenceCallableType(callable)
+    override fun toString(): String = "${callable.name}()@${element.source?.startOffset}"
+}
+
+/** A function declaration; its scheme is the scheme of the function's own symbol. */
+private class FirFunctionInferenceNode(val function: FirFunction) : FirInferenceNode(function) {
+    override val kind get() = NodeKind.Function
+    override val type = InferenceCallableType(function.symbol)
+    override fun toString(): String = function.symbol.name.toString()
+}
+
+/** A lambda literal, typed by the symbol of its anonymous function. */
+private class FirLambdaInferenceNode(val lambda: FirAnonymousFunctionExpression) : FirElementInferenceNode(lambda) {
+    override val kind: NodeKind get() = NodeKind.Lambda
+    override val type = InferenceCallableType(lambda.anonymousFunction.symbol)
+    override fun toString() = "<lambda:${lambda.source?.startOffset}>"
+}
+
+/**
+ * A SAM conversion, treated as a lambda. It is typed by the converted anonymous function when the converted
+ * expression is a lambda literal, and untyped otherwise.
+ */
+private class FirSamInferenceNode(val sam: FirSamConversionExpression) : FirElementInferenceNode(sam) {
+    override val kind: NodeKind get() = NodeKind.Lambda
+    override val type: InferenceNodeType? =
+        (sam.expression as? FirAnonymousFunctionExpression)?.let { InferenceCallableType(it.anonymousFunction.symbol) }
+
+    override fun toString(): String = "<sam:${sam.source?.startOffset}>"
+}
+
+/**
+ * A call made on the [parameterIndex]-th composable parameter of the function represented by
+ * [referenceContainer] (the parameter is the call's explicit receiver).
+ */
+private class FirParameterReferenceNode(
+    element: FirElement,
+    override val parameterIndex: Int,
+    override val referenceContainer: FirInferenceNode,
+) : FirElementInferenceNode(element) {
+    override val kind: NodeKind get() = NodeKind.ParameterReference
+    override fun toString(): String = "param:$parameterIndex"
+}
+
+/**
+ * Returns the node for [expression] used as a reference to [callable]: a parameter-reference node when
+ * [expression] is a call on a composable parameter, the lambda/SAM node an anonymous function was mapped to
+ * by [inferenceNodeOf], or otherwise a node typed by [callable].
+ */
+private fun callableInferenceNodeOf(
+    expression: FirElement,
+    callable: FirCallableSymbol<*>,
+    context: CheckerContext,
+): FirInferenceNode =
+    parameterInferenceNodeOrNull(expression, context)
+        ?: (expression as? FirAnonymousFunction)?.let { mapping[it] }?.let { inferenceNodeOf(it, context) }
+        ?: FirCallableElementInferenceNode(callable, expression)
+
+/** The type of an inference node: the source of the node's declared [Scheme]. */
+private sealed class InferenceNodeType {
+    abstract fun toScheme(context: CheckerContext): Scheme
+    abstract fun isTypeFor(callable: FirCallableSymbol<*>): Boolean
+}
+
+/** A node type backed by a callable symbol; two such types are equal when their callable ids are. */
+private class InferenceCallableType(val callable: FirCallableSymbol<*>) : InferenceNodeType() {
+    override fun toScheme(context: CheckerContext): Scheme = callable.toScheme(context)
+    override fun isTypeFor(callable: FirCallableSymbol<*>) = this.callable.callableId == callable.callableId
+    override fun hashCode(): Int = 31 * callable.callableId.hashCode()
+    override fun equals(other: Any?): Boolean =
+        other is InferenceCallableType && other.callable.callableId == callable.callableId
+}
+
+/**
+ * Computes the applier [Scheme] of this callable. A scheme recorded in `@ComposableInferredTarget` is used
+ * as-is. Otherwise the scheme is built from the callable's own target annotations (deferring to the nearest
+ * annotated enclosing element when unspecified) and its composable parameters, then merged with the schemes
+ * of the functions it directly overrides.
+ */
+fun FirCallableSymbol<*>.toScheme(context: CheckerContext): Scheme =
+    declaredScheme(context) ?: Scheme(
+        target = schemeItem(context).let { item ->
+            // An unspecified target defers to an enclosing element's target, when one exists.
+            if (item.isUnspecified) fileScopeTarget(context) ?: item else item
+        },
+        parameters = parameters(context).map { it.toScheme(context) }
+    ).mergeWith(methodOverrides(context).map { it.toScheme(context) })
+
+/** The functions this callable directly overrides; empty when it is not a function. */
+@OptIn(SymbolInternals::class)
+fun FirCallableSymbol<*>.methodOverrides(context: CheckerContext) =
+    (fir as? FirFunction)?.getDirectOverriddenFunctions(context) ?: emptyList()
+
+/** The composable value parameters of this callable; empty when it is not a function. */
+fun FirCallableSymbol<*>.parameters(context: CheckerContext): List<FirValueParameterSymbol> =
+    (this as? FirFunctionSymbol<*>)?.valueParameterSymbols?.filter { it.isComposable(context) } ?: emptyList()
+
+/**
+ * Returns the target named on this callable's declaration or, failing that, on the nearest enclosing annotated
+ * element, following the parent links recorded by [updateParents].
+ */
+@OptIn(SymbolInternals::class)
+private fun FirCallableSymbol<*>.fileScopeTarget(context: CheckerContext): Item? {
+    var element: FirElement? = fir
+    while (element != null) {
+        (element as? FirAnnotationContainer)?.compositionTarget(context)?.let { return Token(it) }
+        element = element.parent
+    }
+    return null
+}
+
+/** The scheme serialized in this callable's `@ComposableInferredTarget` annotation, if any. */
+fun FirCallableSymbol<*>.declaredScheme(context: CheckerContext) =
+    (annotationArgument(
+        context,
+        ComposeClassIds.ComposableInferredTarget,
+        ComposeFqNames.ComposableInferredTargetSchemeArgument
+    ) as? String)?.let { deserializeScheme(it) }
+
+/**
+ * The scheme item for this callable's own target. This is a token for an explicit (or marker-derived) target,
+ * an open item for an explicit `@ComposableOpenTarget` index, or an unspecified open item otherwise.
+ */
+fun FirCallableSymbol<*>.schemeItem(context: CheckerContext): Item {
+    val explicitTarget = compositionTarget(context)
+    val explicitOpen = compositionOpenTarget(context)
+    return when {
+        explicitTarget != null -> Token(explicitTarget)
+        explicitOpen != null -> Open(explicitOpen)
+        else -> Open(-1, isUnspecified = true)
+    }
+}
+
+/** The target named by `@ComposableTarget`, or else by the first target-marker annotation on this callable. */
+fun FirCallableSymbol<*>.compositionTarget(context: CheckerContext): String? =
+    annotationArgument(context, ComposeClassIds.ComposableTarget, ComposeFqNames.ComposableTargetApplierArgument) as? String
+        ?: annotations.markerTargetName(context)
+
+/**
+ * Returns the target named by an annotation of this type, or null when the annotation class is not marked with
+ * `@ComposableTargetMarker`. A non-empty marker `description` names the target; otherwise the annotation class's
+ * fully-qualified name does.
+ */
+fun ConeKotlinType.targetName(context: CheckerContext): String? {
+    val cls = toClassSymbol(context.session) ?: return null
+    val marker = cls.getAnnotationByClassId(ComposeClassIds.ComposableTargetMarker, context.session) ?: return null
+    // `argument` is null when `description` is not in the argument mapping (presumably when it was left at its
+    // default). That case must also fall back to the class name instead of silently dropping the marker.
+    // NOTE(review): confirm description-based tokens agree with the tokens in back-end serialized schemes.
+    val description = marker.argument(ComposeFqNames.ComposableTargetMarkerDescriptionName) as? String
+    return if (description.isNullOrEmpty()) cls.classId.asFqNameString() else description
+}
+
+/** The index given by `@ComposableOpenTarget`, if present. */
+fun FirCallableSymbol<*>.compositionOpenTarget(context: CheckerContext): Int? =
+    annotationArgument(context, ComposeClassIds.ComposableOpenTarget, ComposeFqNames.ComposableOpenTargetIndexArgument) as? Int
+
+/** The literal value of [argumentName] in this symbol's [classId] annotation, or null. */
+fun FirBasedSymbol<*>.annotationArgument(context: CheckerContext, classId: ClassId, argumentName: Name) =
+    getAnnotationByClassId(classId, context.session)?.argument(argumentName)
+
+/** The target named by `@ComposableTarget`, or else by the first target-marker annotation on this element. */
+fun FirAnnotationContainer.compositionTarget(context: CheckerContext): String? =
+    annotationArgument(context, ComposeClassIds.ComposableTarget, ComposeFqNames.ComposableTargetApplierArgument) as? String
+        ?: annotations.markerTargetName(context)
+
+/** The literal value of [argumentName] in this element's [classId] annotation, or null. */
+fun FirAnnotationContainer.annotationArgument(context: CheckerContext, classId: ClassId, argumentName: Name) =
+    getAnnotationByClassId(classId, context.session)?.argument(argumentName)
+
+/** The first target named by a `@ComposableTargetMarker`-marked annotation in this list, if any. */
+private fun List<FirAnnotation>.markerTargetName(context: CheckerContext): String? =
+    firstNotNullOfOrNull { it.resolvedType.targetName(context) }
+
+/** The value of argument [name] when it is a string or primitive literal; null otherwise. */
+fun FirAnnotation.argument(name: Name): Any? =
+    argumentMapping.mapping[name]?.let {
+        if ((it.resolvedType.isString || it.resolvedType.isPrimitive) && it is FirLiteralExpression) it.value else null
+    }
+
+/**
+ * Reports applier (composition target) mismatches for calls to composable functions. Each call, its resolved
+ * target and its composable arguments are passed to an [ApplierInferencer], which reports inconsistencies.
+ */
+object ComposableTargetChecker : FirFunctionCallChecker(MppCheckerKind.Common) {
+    @OptIn(SymbolInternals::class)
+    override fun check(expression: FirFunctionCall, context: CheckerContext, reporter: DiagnosticReporter) {
+        val calleeFunction = expression.calleeReference.toResolvedCallableSymbol() ?: return
+        if (!calleeFunction.isComposable(context)) return
+
+        updateParents(context)
+        val infer = getInfer(context, reporter)
+        val call = inferenceNodeOf(expression, context)
+        val target = callableInferenceNodeOf(expression, calleeFunction, context)
+        val argumentsMapping = expression.resolvedArgumentMapping
+        val arguments = calleeFunction.parameters(context).mapNotNull { parameter ->
+            // A composable parameter can lack an argument at this call site (e.g. it has a default value).
+            // Skip it rather than throwing NoSuchElementException, as `firstNotNullOf` would.
+            argumentsMapping?.firstNotNullOfOrNull { (argument, mappedParameter) ->
+                if (mappedParameter == parameter.fir) inferenceNodeOf(argument, context) else null
+            }
+        }
+        infer.visitCall(call, target, arguments)
+    }
+}
+
+/**
+ * [ApplierInferencer] over FIR nodes. Declared schemes come from callable symbols, a node's container is the
+ * nearest recorded [FirFunction] ancestor, and mismatches are reported through [reporter].
+ */
+private class FirApplierInference(
+    context: CheckerContext,
+    reporter: DiagnosticReporter,
+) : ApplierInferencer<InferenceNodeType, FirInferenceNode>(
+    typeAdapter = object : TypeAdapter<InferenceNodeType> {
+        override fun declaredSchemaOf(type: InferenceNodeType): Scheme = type.toScheme(context)
+        override fun currentInferredSchemeOf(type: InferenceNodeType): Scheme? = null
+        override fun updatedInferredScheme(type: InferenceNodeType, scheme: Scheme) {}
+    },
+    nodeAdapter = object : NodeAdapter<InferenceNodeType, FirInferenceNode> {
+        override fun containerOf(node: FirInferenceNode): FirInferenceNode {
+            var current = node.element.parent
+            while (current != null) {
+                if (current is FirFunction) return inferenceNodeOf(current, context)
+                current = current.parent
+            }
+            // No enclosing function is recorded: the node acts as its own container.
+            return node
+        }
+
+        override fun kindOf(node: FirInferenceNode): NodeKind = node.kind
+
+        override fun schemeParameterIndexOf(node: FirInferenceNode, container: FirInferenceNode): Int =
+            node.parameterIndex
+
+        override fun typeOf(node: FirInferenceNode): InferenceNodeType? = node.type
+
+        override fun referencedContainerOf(node: FirInferenceNode): FirInferenceNode? = node.referenceContainer
+    },
+    errorReporter = object : ErrorReporter<FirInferenceNode> {
+        private fun descriptionFrom(token: String): String = token // TODO: find the message if appropriate
+
+        override fun reportCallError(node: FirInferenceNode, expected: String, received: String) {
+            if (expected == received) return
+            reporter.reportOn(
+                source = node.element.source,
+                factory = ComposeErrors.COMPOSE_APPLIER_CALL_MISMATCH,
+                context = context,
+                a = descriptionFrom(expected),
+                b = descriptionFrom(received)
+            )
+        }
+
+        override fun reportParameterError(node: FirInferenceNode, index: Int, expected: String, received: String) {
+            // Same rule as reportCallError: identical targets are not a mismatch.
+            if (expected == received) return
+            reporter.reportOn(
+                source = node.element.source,
+                factory = ComposeErrors.COMPOSE_APPLIER_PARAMETER_MISMATCH,
+                context = context,
+                a = descriptionFrom(expected),
+                b = descriptionFrom(received)
+            )
+        }
+
+        override fun log(node: FirInferenceNode?, message: String) {}
+    },
+    lazySchemeStorage = object : LazySchemeStorage<FirInferenceNode> {
+        override fun getLazyScheme(node: FirInferenceNode): LazyScheme? = lazySchemes[node]
+
+        override fun storeLazyScheme(node: FirInferenceNode, value: LazyScheme) {
+            lazySchemes[node] = value
+        }
+    }
+)
+
+/**
+ * Lazy schemes of inference nodes, shared by every [FirApplierInference] instance. Node equality is element
+ * equality, so this map is effectively keyed by FIR element.
+ *
+ * NOTE(review): this map and [mapping] are top-level and never cleared, so they keep FIR elements reachable
+ * for as long as this class stays loaded (presumably across compilations in a long-lived daemon). Consider
+ * scoping them to the FIR session.
+ */
+private val lazySchemes = mutableMapOf<FirInferenceNode, LazyScheme>()
+
+/**
+ * Elements that must be treated as identical for inference. [inferenceNodeOf] maps the anonymous function of a
+ * lambda or SAM-converted lambda to the expression that wraps it.
+ */
+private val mapping = mutableMapOf<FirElement, FirElement>()
+
+/**
+ * Returns the inference node for [element], creating it and recording it in the trace on first use. Creating a
+ * node for a lambda or a SAM-converted lambda also records its anonymous function in [mapping].
+ */
+private fun inferenceNodeOf(element: FirElement, context: CheckerContext): FirInferenceNode =
+    element.firTrace[WritableSlices.NODE, element] ?: when (element) {
+        is FirAnonymousFunctionExpression -> {
+            mapping[element.anonymousFunction] = element
+            FirLambdaInferenceNode(element)
+        }
+        is FirSamConversionExpression -> {
+            (element.expression as? FirAnonymousFunctionExpression)?.let { mapping[it.anonymousFunction] = element }
+            FirSamInferenceNode(element)
+        }
+        is FirAnonymousFunction -> callableInferenceNodeOf(element, element.symbol, context)
+        is FirFunction -> FirFunctionInferenceNode(element)
+        else -> parameterInferenceNodeOrNull(element, context) ?: FirElementInferenceNode(element)
+    }.also { element.firTrace.record(WritableSlices.NODE, element, it) }
+
+/**
+ * Returns a parameter-reference node when [expression] is a call whose explicit receiver resolves to a composable
+ * value parameter, or null otherwise. The node's container is the function that declares the parameter.
+ */
+@OptIn(SymbolInternals::class)
+private fun parameterInferenceNodeOrNull(expression: FirElement, context: CheckerContext): FirInferenceNode? {
+    if (expression !is FirFunctionCall) return null
+    val receiver = expression.explicitReceiver as? FirQualifiedAccessExpression ?: return null
+    val parameterSymbol = receiver.toResolvedCallableSymbol() as? FirValueParameterSymbol ?: return null
+    val function = parameterSymbol.containingFunctionSymbol
+    // Indices count composable parameters only, the same filter that `parameters()` applies.
+    val index = function.valueParameterSymbols.filter { it.isComposable(context) }.indexOf(parameterSymbol)
+    return if (index >= 0) FirParameterReferenceNode(expression, index, inferenceNodeOf(function.fir, context)) else null
+}
+
+/**
+ * Creates the inferencer for one checker invocation. Instances are not reused, because their adapters capture
+ * [context] and [reporter] at construction. Shared state lives in [lazySchemes], [mapping] and the trace slices.
+ */
+private fun getInfer(context: CheckerContext, reporter: DiagnosticReporter): FirApplierInference =
+    FirApplierInference(context, reporter)
+
+/**
+ * Records parent links along the checker's current containing-element path so the inference adapters can walk
+ * upwards. Walks from the innermost element outwards and stops at the first element whose parent is known.
+ */
+private fun updateParents(context: CheckerContext) {
+    val containingElements = context.containingElements
+    for (i in containingElements.lastIndex downTo 1) {
+        val element = containingElements[i]
+        if (element.parent != null) break
+        element.firTrace.record(WritableSlices.PARENT, element, containingElements[i - 1])
+    }
+}
+
+/** The parent recorded for this element by [updateParents], if any. */
+private val FirElement.parent get() = firTrace[WritableSlices.PARENT, this]
+
+/** Trace slices for FIR applier inference, stored in the global weak binding trace. */
+private object WritableSlices {
+    val NODE: WritableSlice<FirElement, FirInferenceNode> =
+        BasicWritableSlice(RewritePolicy.DO_NOTHING)
+    val PARENT: WritableSlice<FirElement, FirElement?> =
+        BasicWritableSlice(RewritePolicy.DO_NOTHING)
+}
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt index af02ccd..7088454 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt
@@ -99,5 +99,24 @@ ComposeErrors.MISMATCHED_COMPOSABLE_IN_EXPECT_ACTUAL, "Mismatched @Composable annotation between expect and actual declaration" ) + + /* {0}: expected applier target, {1}: received applier target */ map.put( + ComposeErrors.COMPOSE_APPLIER_CALL_MISMATCH, + "Calling a {1} composable function where a {0} composable was expected", + FirDiagnosticRenderers.TO_STRING, + FirDiagnosticRenderers.TO_STRING + ) + + /* {0}: expected applier target, {1}: received applier target */ map.put( + ComposeErrors.COMPOSE_APPLIER_PARAMETER_MISMATCH, + "A {1} composable parameter was provided where a {0} composable was expected", + FirDiagnosticRenderers.TO_STRING, + FirDiagnosticRenderers.TO_STRING + ) + + map.put( + ComposeErrors.COMPOSE_APPLIER_DECLARATION_MISMATCH, + "The composition target of an override must match the ancestor target" + ) } }
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt index 2c2ceaa..2e1c454 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt
@@ -20,17 +20,7 @@ import com.intellij.openapi.util.TextRange import com.intellij.psi.PsiElement import com.intellij.util.diff.FlyweightCapableTreeStructure -import org.jetbrains.kotlin.diagnostics.LightTreePositioningStrategies -import org.jetbrains.kotlin.diagnostics.LightTreePositioningStrategy -import org.jetbrains.kotlin.diagnostics.PositioningStrategies -import org.jetbrains.kotlin.diagnostics.PositioningStrategy -import org.jetbrains.kotlin.diagnostics.SourceElementPositioningStrategies -import org.jetbrains.kotlin.diagnostics.SourceElementPositioningStrategy -import org.jetbrains.kotlin.diagnostics.error0 -import org.jetbrains.kotlin.diagnostics.error2 -import org.jetbrains.kotlin.diagnostics.error3 -import org.jetbrains.kotlin.diagnostics.findChildByType -import org.jetbrains.kotlin.diagnostics.markElement +import org.jetbrains.kotlin.diagnostics.* import org.jetbrains.kotlin.diagnostics.rendering.RootDiagnosticRendererFactory import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol import org.jetbrains.kotlin.fir.symbols.impl.FirValueParameterSymbol @@ -51,7 +41,7 @@ val NONREADONLY_CALL_IN_READONLY_COMPOSABLE by error0<PsiElement>() val CAPTURED_COMPOSABLE_INVOCATION by - error2<PsiElement, FirVariableSymbol<*>, FirCallableSymbol<*>>() + error2<PsiElement, FirVariableSymbol<*>, FirCallableSymbol<*>>() // composable calls are not allowed in try expressions // error goes on the `try` keyword @@ -60,10 +50,10 @@ ) val MISSING_DISALLOW_COMPOSABLE_CALLS_ANNOTATION by error3< - PsiElement, - FirValueParameterSymbol, // unmarked - FirValueParameterSymbol, // marked - FirCallableSymbol<*>>() + PsiElement, + FirValueParameterSymbol, // unmarked + FirValueParameterSymbol, // marked + FirCallableSymbol<*>>() val ABSTRACT_COMPOSABLE_DEFAULT_PARAMETER_VALUE by error0<PsiElement>() @@ -91,6 +81,14 @@ SourceElementPositioningStrategies.DECLARATION_NAME ) + /* Applier (composition target) mismatches are warnings; the two String arguments are the expected and the received target. */ val COMPOSE_APPLIER_CALL_MISMATCH by warning2<PsiElement, String, String>() + val
COMPOSE_APPLIER_PARAMETER_MISMATCH by warning2<PsiElement, String, String>() + + val COMPOSE_APPLIER_DECLARATION_MISMATCH by warning0<PsiElement>( + ComposeSourceElementPositioningStrategies.DECLARATION_NAME_OR_DEFAULT + ) + init { RootDiagnosticRendererFactory.registerFactory(ComposeErrorMessages) } @@ -105,20 +103,20 @@ } return PositioningStrategies.DEFAULT.mark(element) } - } + } private val LIGHT_TREE_TRY_KEYWORD: LightTreePositioningStrategy = object : LightTreePositioningStrategy() { - override fun mark( - node: LighterASTNode, - startOffset: Int, - endOffset: Int, - tree: FlyweightCapableTreeStructure<LighterASTNode> - ): List<TextRange> { - val target = tree.findChildByType(node, KtTokens.TRY_KEYWORD) ?: node - return markElement(target, startOffset, endOffset, tree, node) + override fun mark( + node: LighterASTNode, + startOffset: Int, + endOffset: Int, + tree: FlyweightCapableTreeStructure<LighterASTNode>, + ): List<TextRange> { + val target = tree.findChildByType(node, KtTokens.TRY_KEYWORD) ?: node + return markElement(target, startOffset, endOffset, tree, node) + } } - } private val PSI_DECLARATION_NAME_OR_DEFAULT: PositioningStrategy<PsiElement> = object : PositioningStrategy<PsiElement>() {
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt index ee5a0f4..700454c 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt
@@ -103,7 +103,7 @@ override val expressionCheckers: ExpressionCheckers = object : ExpressionCheckers() { override val functionCallCheckers: Set<FirFunctionCallChecker> = - setOf(ComposableFunctionCallChecker) + setOf(ComposableFunctionCallChecker, ComposableTargetChecker) override val propertyAccessExpressionCheckers: Set<FirPropertyAccessExpressionChecker> = setOf(ComposablePropertyAccessExpressionChecker)
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt index 56d0cc0..c844504 100644 --- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt +++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt
@@ -17,39 +17,30 @@ package androidx.compose.compiler.plugins.kotlin.k2 import androidx.compose.compiler.plugins.kotlin.ComposeClassIds +import org.jetbrains.kotlin.descriptors.Modality import org.jetbrains.kotlin.fir.FirAnnotationContainer import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext import org.jetbrains.kotlin.fir.analysis.checkers.getAnnotationStringParameter import org.jetbrains.kotlin.fir.analysis.checkers.unsubstitutedScope import org.jetbrains.kotlin.fir.containingClassLookupTag -import org.jetbrains.kotlin.fir.declarations.FirFunction -import org.jetbrains.kotlin.fir.declarations.FirPropertyAccessor -import org.jetbrains.kotlin.fir.declarations.FirResolvePhase -import org.jetbrains.kotlin.fir.declarations.hasAnnotation +import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.declarations.utils.isOverride +import org.jetbrains.kotlin.fir.declarations.utils.modality import org.jetbrains.kotlin.fir.expressions.FirFunctionCall import org.jetbrains.kotlin.fir.expressions.FirReturnExpression import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol +import org.jetbrains.kotlin.fir.resolve.toClassSymbol import org.jetbrains.kotlin.fir.resolve.toSymbol +import org.jetbrains.kotlin.fir.scopes.collectAllFunctions import org.jetbrains.kotlin.fir.scopes.getDirectOverriddenFunctions import org.jetbrains.kotlin.fir.scopes.getDirectOverriddenProperties +import org.jetbrains.kotlin.fir.scopes.unsubstitutedScope import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol import org.jetbrains.kotlin.fir.symbols.SymbolInternals -import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol -import org.jetbrains.kotlin.fir.symbols.impl.FirPropertyAccessorSymbol -import 
org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol +import org.jetbrains.kotlin.fir.symbols.impl.* import org.jetbrains.kotlin.fir.symbols.lazyResolveToPhase -import org.jetbrains.kotlin.fir.types.ConeKotlinType -import org.jetbrains.kotlin.fir.types.ProjectionKind -import org.jetbrains.kotlin.fir.types.coneType -import org.jetbrains.kotlin.fir.types.isArrayType -import org.jetbrains.kotlin.fir.types.isString -import org.jetbrains.kotlin.fir.types.isUnit -import org.jetbrains.kotlin.fir.types.type +import org.jetbrains.kotlin.fir.types.* import org.jetbrains.kotlin.name.JvmStandardClassIds fun FirAnnotationContainer.hasComposableAnnotation(session: FirSession): Boolean = @@ -67,17 +58,36 @@ fun FirAnnotationContainer.hasDisallowComposableCallsAnnotation(session: FirSession): Boolean = hasAnnotation(ComposeClassIds.DisallowComposableCalls, session) -fun FirCallableSymbol<*>.isComposable(session: FirSession): Boolean = +fun FirAnnotationContainer.hasComposableTargetMarkerAnnotation(session: FirSession): Boolean = + hasAnnotation(ComposeClassIds.ComposableTargetMarker, session) + +fun FirCallableSymbol<*>.isComposable(context: CheckerContext): Boolean = when (this) { is FirFunctionSymbol<*> -> - hasComposableAnnotation(session) + hasComposableAnnotation(context.session) is FirPropertySymbol -> getterSymbol?.let { - it.hasComposableAnnotation(session) || it.isComposableDelegate(session) + it.hasComposableAnnotation(context.session) || it.isComposableDelegate(context) } ?: false + is FirValueParameterSymbol -> isComposable(context) else -> false } +private fun FirValueParameterSymbol.isComposable(context: CheckerContext): Boolean = + resolvedReturnType.customAnnotations.hasAnnotation(ComposeClassIds.Composable, context.session) || + findSamFunction(context)?.isComposable(context) == true + +private fun FirValueParameterSymbol.findSamFunction(context: CheckerContext): FirNamedFunctionSymbol? 
{ + val type = resolvedReturnType + val session = context.session + val classSymbol = type.toClassSymbol(session) ?: return null + val samFunction = classSymbol + .unsubstitutedScope(session, context.scopeSession, withForcedTypeCalculator = true, memberRequiredPhase = null) + .collectAllFunctions() + .singleOrNull { it.modality == Modality.ABSTRACT } + return samFunction +} + fun FirCallableSymbol<*>.isReadOnlyComposable(session: FirSession): Boolean = when (this) { is FirFunctionSymbol<*> -> @@ -88,7 +98,7 @@ } @OptIn(SymbolInternals::class) -private fun FirPropertyAccessorSymbol.isComposableDelegate(session: FirSession): Boolean { +private fun FirPropertyAccessorSymbol.isComposableDelegate(context: CheckerContext): Boolean { if (!propertySymbol.hasDelegate) return false fir.lazyResolveToPhase(FirResolvePhase.BODY_RESOLVE) return ((fir @@ -98,7 +108,7 @@ ?.result as? FirFunctionCall) ?.calleeReference ?.toResolvedCallableSymbol() - ?.isComposable(session) + ?.isComposable(context) ?: false } @@ -173,5 +183,5 @@ private val FirFunctionSymbol<*>.explicitParameterTypes: List<ConeKotlinType> get() = resolvedContextReceivers.map { it.typeRef.coneType } + - listOfNotNull(receiverParameter?.typeRef?.coneType) + - valueParameterSymbols.map { it.resolvedReturnType } + listOfNotNull(receiverParameter?.typeRef?.coneType) + + valueParameterSymbols.map { it.resolvedReturnType }