[Compose] Enable applier checking when using FIR

Adds ComposableTarget checking when using the FIR front-end.

Fixes: [282135108](https://issuetracker.google.com/282135108)
Fixes: [349866442](https://issuetracker.google.com/349866442)
diff --git a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt
index 257cd4b..e328a88 100644
--- a/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt
+++ b/compiler/fir/checkers/src/org/jetbrains/kotlin/fir/analysis/diagnostics/FirDiagnosticRenderers.kt
@@ -313,4 +313,6 @@
             else -> "declaration"
         }
     }
+
+    val TO_STRING = Renderer { value: Any -> value.toString() } // plain toString(); used by COMPOSE_APPLIER_* messages
 }
diff --git a/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt b/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt
index a262707..ada621c 100644
--- a/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt
+++ b/plugins/compose/compiler-hosted/integration-tests/src/jvmTest/kotlin/androidx/compose/compiler/plugins/kotlin/analysis/ComposableTargetCheckerTests.kt
@@ -22,8 +22,7 @@
 import org.junit.runner.RunWith
 import org.junit.runners.JUnit4
 
-@RunWith(JUnit4::class)
-class ComposableTargetCheckerTests : AbstractComposeDiagnosticsTest(useFir = false) {
+class ComposableTargetCheckerTests(useFir: Boolean) : AbstractComposeDiagnosticsTest(useFir) {
     @Test
     fun testExplicitTargetAnnotations() = check(
         """
@@ -174,7 +173,7 @@
         @Composable
         @ComposableTarget("N")
         fun T() {
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
 
         @Composable
@@ -191,7 +190,7 @@
         @Composable
         fun T() {
             N()
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
 
         @Composable
@@ -225,8 +224,8 @@
             W {
                 N()
             }
-            <!COMPOSE_APPLIER_PARAMETER_MISMATCH!>W<!> {
-                M()
+            ${psiParStart()}W${psiEnd()} {
+                ${firMisStart()}M()${firEnd()}
             }
         }
         """
@@ -254,14 +253,14 @@
         }
 
         @Composable
-        fun OpenCustom(content: CustomComposable) {
-          content.call()
+        fun OpenCustom(oContent: CustomComposable) {
+          oContent.call()
         }
 
         @Composable
-        fun ClosedCustom(content: CustomComposable) {
+        fun ClosedCustom(cContent: CustomComposable) {
           N()
-          content.call()
+          cContent.call()
         }
 
         @Composable
@@ -285,16 +284,16 @@
           OpenCustom {
             N()
           }
-          <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+          <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
 
         @Composable
         fun ClosedDisagree() {
           ClosedCustom {
             N()
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
           }
-          <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+          <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
         """
     )
@@ -312,7 +311,7 @@
 
         @Composable
         fun AssumesN() {
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
         """
     )
@@ -340,7 +339,7 @@
         @Composable
         @NComposable
         fun AssumesN() {
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
         """
     )
@@ -369,7 +368,7 @@
 
         @Composable
         fun AssumesN() {
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
         }
         """
     )
@@ -387,7 +386,7 @@
         @Composable
         fun UseText() {
            BasicText("Some text")
-           <!COMPOSE_APPLIER_CALL_MISMATCH!>Invalid<!>()
+           <!COMPOSE_APPLIER_CALL_MISMATCH!>Invalid${psiEnd()}()${firEnd()}
         }
         """,
         additionalPaths = listOf(
@@ -413,7 +412,7 @@
 
         class Invalid : Base() {
           @Composable override fun Compose() {
-            <!COMPOSE_APPLIER_CALL_MISMATCH!>M<!>()
+            <!COMPOSE_APPLIER_CALL_MISMATCH!>M${psiEnd()}()${firEnd()}
           }
         }
 
@@ -439,7 +438,7 @@
         }
 
         class Invalid : Base() {
-          <!COMPOSE_APPLIER_DECLARATION_MISMATCH!>@Composable @ComposableTarget("M") override fun Compose() { }<!>
+          ${psiDecStart()}@Composable @ComposableTarget("M") override fun ${firDecStart()}Compose${firEnd()}() { }${psiEnd()}
         }
 
         class Valid : Base () {
@@ -469,4 +468,11 @@
           }
         }
         """)
+
+    private fun firEnd() = if (useFir) "<!>" else "" // closes a marker only in K2 (FIR) runs
+    private fun psiEnd() = if (useFir) "" else "<!>" // closes a marker only in K1 (PSI) runs
+    private fun firMisStart() = if (useFir) "<!COMPOSE_APPLIER_CALL_MISMATCH!>" else ""
+    private fun psiParStart() = if (useFir) "" else "<!COMPOSE_APPLIER_PARAMETER_MISMATCH!>"
+    private fun firDecStart() = if (useFir) "<!COMPOSE_APPLIER_DECLARATION_MISMATCH!>" else ""
+    private fun psiDecStart() = if (useFir) "" else "<!COMPOSE_APPLIER_DECLARATION_MISMATCH!>"
 }
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt
index 337948c..c6d4885 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/ComposeFqNames.kt
@@ -44,6 +44,7 @@
     val ComposableLambda = internalClassIdFor("ComposableLambda")
     val ComposableOpenTarget = classIdFor("ComposableOpenTarget")
     val ComposableTarget = classIdFor("ComposableTarget")
+    val ComposableTargetMarker = classIdFor("ComposableTargetMarker") // looked up on annotation classes by the K2 target checker
     val ComposeVersion = classIdFor("ComposeVersion")
     val Composer = classIdFor("Composer")
     val DisallowComposableCalls = classIdFor("DisallowComposableCalls")
@@ -102,6 +103,7 @@
     val ComposableTarget = ComposeClassIds.ComposableTarget.asSingleFqName()
     val ComposableTargetMarker = fqNameFor("ComposableTargetMarker")
     val ComposableTargetMarkerDescription = "description"
+    val ComposableTargetMarkerDescriptionName = Name.identifier(ComposableTargetMarkerDescription) // Name form for FIR argument lookup
     val ComposableTargetApplierArgument = Name.identifier("applier")
     val ComposableOpenTarget = ComposeClassIds.ComposableOpenTarget.asSingleFqName()
     val ComposableOpenTargetIndexArgument = Name.identifier("index")
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt
index 2908760..09b9056 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/WeakBindingTrace.kt
@@ -19,7 +19,9 @@
 import com.intellij.util.keyFMap.KeyFMap
 import java.util.WeakHashMap
 import org.jetbrains.kotlin.backend.common.extensions.IrPluginContext
+import org.jetbrains.kotlin.fir.FirElement
 import org.jetbrains.kotlin.ir.declarations.IrAttributeContainer
+import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
 import org.jetbrains.kotlin.util.slicedMap.ReadOnlySlice
 import org.jetbrains.kotlin.util.slicedMap.WritableSlice
 import java.util.Collections
@@ -46,9 +48,37 @@
     operator fun <K : IrAttributeContainer, V> get(slice: ReadOnlySlice<K, V>, key: K): V? {
         return map[key.attributeOwnerId]?.get(slice.key)
     }
+
+    // FIR keys (elements and checker contexts) are stored directly, unlike IR's attributeOwnerId keys.
+    fun <K : FirElement, V> record(slice: WritableSlice<K, V>, key: K, value: V) {
+        recordAny(slice, key, value)
+    }
+
+    operator fun <K : FirElement, V> get(slice: ReadOnlySlice<K, V>, key: K): V? = getAny(slice, key)
+
+    fun <K : CheckerContext, V> record(slice: WritableSlice<K, V>, key: K, value: V) {
+        recordAny(slice, key, value)
+    }
+
+    operator fun <K : CheckerContext, V> get(slice: ReadOnlySlice<K, V>, key: K): V? = getAny(slice, key)
+
+    private fun <K : Any, V> getAny(slice: ReadOnlySlice<K, V>, key: K): V? = map[key]?.get(slice.key)
+
+    // KeyFMap.plus already replaces an existing value. A null value (possible for nullable-valued
+    // slices) clears the entry instead of failing on `!!`.
+    private fun <K : Any, V> recordAny(slice: WritableSlice<K, V>, key: K, value: V) {
+        val holder = map[key] ?: KeyFMap.EMPTY_MAP
+        map[key] = if (value == null) holder.minus(slice.key) else holder.plus(slice.key, value)
+    }
 }
 
 private val ComposeTemporaryGlobalBindingTrace = WeakBindingTrace()
 
 @Suppress("unused")
 val IrPluginContext.irTrace: WeakBindingTrace get() = ComposeTemporaryGlobalBindingTrace
+
+@Suppress("unused")
+val CheckerContext.firTrace: WeakBindingTrace get() = ComposeTemporaryGlobalBindingTrace // same object as irTrace
+
+@Suppress("unused")
+val FirElement.firTrace: WeakBindingTrace get() = ComposeTemporaryGlobalBindingTrace // same object as irTrace
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt
index 4253afb..0f45e97 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/inference/ApplierInferencer.kt
@@ -164,7 +164,7 @@
  * back-end IR nodes as well as allows for easier testing and debugging of the itself algorithm
  * without requiring either tree.
  */
-class ApplierInferencer<Type, Node>(
+open class ApplierInferencer<Type, Node>(
     private val typeAdapter: TypeAdapter<Type>,
     private val nodeAdapter: NodeAdapter<Type, Node>,
     private val lazySchemeStorage: LazySchemeStorage<Node>,
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt
index f74024d..05d8d85 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableCallChecker.kt
@@ -62,7 +62,7 @@
     ) {
         val calleeFunction = expression.calleeReference.toResolvedCallableSymbol()
             ?: return
-        if (calleeFunction.isComposable(context.session)) {
+        if (calleeFunction.isComposable(context)) {
             checkComposableCall(expression, calleeFunction, context, reporter)
         }
     }
@@ -81,7 +81,7 @@
         // https://youtrack.jetbrains.com/issue/KT-47708.
         if (calleeFunction.origin == FirDeclarationOrigin.SamConstructor) return
 
-        if (calleeFunction.isComposable(context.session)) {
+        if (calleeFunction.isComposable(context)) {
             checkComposableCall(expression, calleeFunction, context, reporter)
         } else if (calleeFunction.callableId.isInvoke()) {
             checkInvoke(expression, context, reporter)
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt
index 3c19708..b845b45 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableFunctionChecker.kt
@@ -41,16 +41,20 @@
 
         // Check overrides for mismatched composable annotations
         for (override in declaration.getDirectOverriddenFunctions(context)) {
-            if (override.isComposable(context.session) != isComposable) {
+            if (override.isComposable(context) != isComposable) {
                 reporter.reportOn(
                     declaration.source,
                     FirErrors.CONFLICTING_OVERLOADS,
                     listOf(declaration.symbol, override),
                     context
                 )
+            } else if (override.isComposable(context) && !override.toScheme(context).canOverride(declaration.symbol.toScheme(context))) {
+                reporter.reportOn(
+                    source = declaration.source,
+                    factory = ComposeErrors.COMPOSE_APPLIER_DECLARATION_MISMATCH,
+                    context = context
+                )
             }
-
-            // TODO(b/282135108): Check scheme of override against declaration
         }
 
         // Check that `actual` composable declarations have composable expects
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableTargetChecker.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableTargetChecker.kt
new file mode 100644
index 0000000..782ca7e
--- /dev/null
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposableTargetChecker.kt
@@ -0,0 +1,356 @@
+/*
+ * Copyright 2010-2024 JetBrains s.r.o. and Kotlin Programming Language contributors.
+ * Use of this source code is governed by the Apache 2.0 license that can be found in the license/LICENSE.txt file.
+ */
+
+package androidx.compose.compiler.plugins.kotlin.k2
+
+import androidx.compose.compiler.plugins.kotlin.ComposeClassIds
+import androidx.compose.compiler.plugins.kotlin.ComposeFqNames
+import androidx.compose.compiler.plugins.kotlin.firTrace
+import androidx.compose.compiler.plugins.kotlin.inference.*
+import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
+import org.jetbrains.kotlin.diagnostics.reportOn
+import org.jetbrains.kotlin.fir.FirAnnotationContainer
+import org.jetbrains.kotlin.fir.FirElement
+import org.jetbrains.kotlin.fir.analysis.checkers.MppCheckerKind
+import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
+import org.jetbrains.kotlin.fir.analysis.checkers.expression.FirFunctionCallChecker
+import org.jetbrains.kotlin.fir.declarations.FirAnonymousFunction
+import org.jetbrains.kotlin.fir.declarations.FirFunction
+import org.jetbrains.kotlin.fir.declarations.getAnnotationByClassId
+import org.jetbrains.kotlin.fir.expressions.*
+import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
+import org.jetbrains.kotlin.fir.resolve.toClassSymbol
+import org.jetbrains.kotlin.fir.scopes.impl.overrides
+import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
+import org.jetbrains.kotlin.fir.symbols.SymbolInternals
+import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
+import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
+import org.jetbrains.kotlin.fir.symbols.impl.FirValueParameterSymbol
+import org.jetbrains.kotlin.fir.types.ConeKotlinType
+import org.jetbrains.kotlin.fir.types.isPrimitive
+import org.jetbrains.kotlin.fir.types.isString
+import org.jetbrains.kotlin.fir.types.resolvedType
+import org.jetbrains.kotlin.name.ClassId
+import org.jetbrains.kotlin.name.Name
+import org.jetbrains.kotlin.util.slicedMap.BasicWritableSlice
+import org.jetbrains.kotlin.util.slicedMap.RewritePolicy
+import org.jetbrains.kotlin.util.slicedMap.WritableSlice
+
+private sealed class FirInferenceNode(val element: FirElement) { // identity (equals/hashCode) is the wrapped element
+    open val kind: NodeKind get() = NodeKind.Expression
+    abstract val type: InferenceNodeType?
+    open val referenceContainer: FirInferenceNode? get() = null
+    open val parameterIndex: Int get() = -1
+    override fun hashCode(): Int = 31 * element.hashCode()
+    override fun equals(other: Any?): Boolean = other is FirInferenceNode && other.element == element
+}
+
+private open class FirElementInferenceNode(element: FirElement) : FirInferenceNode(element) { // untyped node
+    override val type: InferenceNodeType? get() = null
+}
+
+// Call-site node typed by the callable it resolves to.
+private class FirCallableElementInferenceNode(val callable: FirCallableSymbol<*>, element: FirElement) : FirElementInferenceNode(element) {
+    override val type: InferenceNodeType = InferenceCallableType(callable)
+    override fun toString(): String = "${callable.name}()@${element.source?.startOffset}"
+}
+
+// Function declaration node (NodeKind.Function), typed by the function's own symbol.
+private class FirFunctionInferenceNode(val function: FirFunction) : FirInferenceNode(function) {
+    override val kind get() = NodeKind.Function
+    override val type = InferenceCallableType(function.symbol)
+    override fun toString(): String = function.symbol.name.toString()
+}
+
+// Lambda expression node, typed by its anonymous function.
+private class FirLambdaInferenceNode(val lambda: FirAnonymousFunctionExpression) : FirElementInferenceNode(lambda) {
+    override val kind: NodeKind get() = NodeKind.Lambda
+    override val type = InferenceCallableType(lambda.anonymousFunction.symbol)
+    override fun toString() = "<lambda:${lambda.source?.startOffset}>"
+}
+
+// SAM-conversion node, typed only when it converts a lambda. NOTE(review): `context` is unused.
+private class FirSamInferenceNode(context: CheckerContext, val sam: FirSamConversionExpression) : FirElementInferenceNode(sam) {
+    override val kind: NodeKind get() = NodeKind.Lambda
+    override val type: InferenceNodeType? =
+        (sam.expression as? FirAnonymousFunctionExpression)?.let { InferenceCallableType(it.anonymousFunction.symbol) }
+    override fun toString(): String = "<sam:${sam.source?.startOffset}>"
+}
+
+// Call through a composable value parameter; `parameterIndex` counts only composable parameters.
+private class FirParameterReferenceNode(
+    element: FirElement, override val parameterIndex: Int, override val referenceContainer: FirInferenceNode
+) : FirElementInferenceNode(element) {
+    override val kind: NodeKind get() = NodeKind.ParameterReference
+    override fun toString(): String = "param:$parameterIndex"
+}
+
+private fun callableInferenceNodeOf(expression: FirElement, callable: FirCallableSymbol<*>, context: CheckerContext) =
+    parameterInferenceNodeOrNull(expression, context)
+        // An anonymous function shares the node of the lambda/SAM expression recorded for it in `mapping`.
+        ?: (expression as? FirAnonymousFunction)?.let { mapping[it] }
+            ?.let { inferenceNodeOf(it, context) }
+        ?: FirCallableElementInferenceNode(callable, expression)
+
+private sealed class InferenceNodeType {
+    abstract fun toScheme(context: CheckerContext): Scheme
+    abstract fun isTypeFor(callable: FirCallableSymbol<*>): Boolean
+}
+
+// Type of a callable node. NOTE(review): compares by CallableId only, which overloads share.
+private class InferenceCallableType(val callable: FirCallableSymbol<*>) : InferenceNodeType() {
+    override fun toScheme(context: CheckerContext): Scheme = callable.toScheme(context)
+    override fun isTypeFor(callable: FirCallableSymbol<*>) = this.callable.callableId == callable.callableId
+    override fun hashCode(): Int = 31 * callable.callableId.hashCode()
+    override fun equals(other: Any?): Boolean = (other as? InferenceCallableType)?.callable?.callableId == callable.callableId
+}
+
+/**
+ * Applier scheme of this callable: a serialized `@ComposableInferredTarget` scheme when present, otherwise one
+ * built from its target annotations and composable parameters, merged with the schemes of what it overrides.
+ */
+fun FirCallableSymbol<*>.toScheme(context: CheckerContext): Scheme {
+    declaredScheme(context)?.let { return it }
+    val item = schemeItem(context)
+    // An unspecified item falls back to a target found on an enclosing element.
+    val target = if (item.isUnspecified) fileScopeTarget(context) ?: item else item
+    val own = Scheme(target = target, parameters = parameters(context).map { it.toScheme(context) })
+    return own.mergeWith(methodOverrides(context).map { it.toScheme(context) })
+}
+
+// Functions this callable directly overrides; empty unless its declaration is a FirFunction.
+@OptIn(SymbolInternals::class)
+fun FirCallableSymbol<*>.methodOverrides(context: CheckerContext) = (fir as? FirFunction)?.getDirectOverriddenFunctions(context).orEmpty()
+
+// Composable value parameters of a function symbol (empty for other callables).
+fun FirCallableSymbol<*>.parameters(context: CheckerContext): List<FirValueParameterSymbol> =
+    (this as? FirFunctionSymbol<*>)?.valueParameterSymbols?.filter { it.isComposable(context) }.orEmpty()
+
+@OptIn(SymbolInternals::class)
+private fun FirCallableSymbol<*>.fileScopeTarget(context: CheckerContext): Item? =
+    // Walk from this declaration up its recorded parents until an annotated element names a target.
+    generateSequence<FirElement>(fir) { it.parent }.firstNotNullOfOrNull { element ->
+        (element as? FirAnnotationContainer)?.compositionTarget(context)?.let { Token(it) }
+    }
+
+fun FirCallableSymbol<*>.declaredScheme(context: CheckerContext): Scheme? {
+    // Scheme serialized in the scheme argument of @ComposableInferredTarget, if present.
+    val serialized = annotationArgument(
+        context, ComposeClassIds.ComposableInferredTarget,
+        ComposeFqNames.ComposableInferredTargetSchemeArgument
+    ) as? String ?: return null
+    return deserializeScheme(serialized)
+}
+
+/**
+ * Declared target item: a [Token] for a named target (@ComposableTarget or a marker annotation),
+ * an [Open] for @ComposableOpenTarget, otherwise an unspecified [Open].
+ */
+fun FirCallableSymbol<*>.schemeItem(context: CheckerContext): Item {
+    compositionTarget(context)?.let { return Token(it) }
+    compositionOpenTarget(context)?.let { return Open(it) }
+    return Open(-1, isUnspecified = true)
+}
+
+// Target named by @ComposableTarget(applier), else by the first annotation whose class has a target marker.
+fun FirCallableSymbol<*>.compositionTarget(context: CheckerContext): String? {
+    val applier = annotationArgument(context, ComposeClassIds.ComposableTarget, ComposeFqNames.ComposableTargetApplierArgument)
+    if (applier is String) return applier
+    return annotations.firstNotNullOfOrNull { it.resolvedType.targetName(context) }
+}
+
+// Marker-derived target name; a missing (defaulted) or empty description falls back to the class FqName.
+fun ConeKotlinType.targetName(context: CheckerContext): String? {
+    val cls = toClassSymbol(context.session) ?: return null
+    val marker = cls.getAnnotationByClassId(ComposeClassIds.ComposableTargetMarker, context.session) ?: return null
+    val description = marker.argument(ComposeFqNames.ComposableTargetMarkerDescriptionName) as? String
+    return if (description.isNullOrEmpty()) cls.classId.asFqNameString() else description
+}
+
+fun FirCallableSymbol<*>.compositionOpenTarget(context: CheckerContext): Int? = // index of @ComposableOpenTarget, if any
+    annotationArgument(context, ComposeClassIds.ComposableOpenTarget, ComposeFqNames.ComposableOpenTargetIndexArgument) as? Int
+
+fun FirBasedSymbol<*>.annotationArgument(context: CheckerContext, classId: ClassId, argumentName: Name) = // literal or null
+    getAnnotationByClassId(classId, context.session)?.argument(argumentName)
+
+// Target named by @ComposableTarget(applier) on this element, else via a marker-annotated annotation.
+fun FirAnnotationContainer.compositionTarget(context: CheckerContext): String? {
+    val applier = annotationArgument(context, ComposeClassIds.ComposableTarget, ComposeFqNames.ComposableTargetApplierArgument)
+    if (applier is String) return applier
+    return annotations.firstNotNullOfOrNull { it.resolvedType.targetName(context) }
+}
+
+fun FirAnnotationContainer.annotationArgument(context: CheckerContext, classId: ClassId, argumentName: Name) = // literal or null
+    getAnnotationByClassId(classId, context.session)?.argument(argumentName)
+
+fun FirAnnotation.argument(name: Name): Any? { // String/primitive literal value; null when absent or not such a literal
+    val literal = argumentMapping.mapping[name] as? FirLiteralExpression ?: return null
+    val type = literal.resolvedType
+    return if (type.isString || type.isPrimitive) literal.value else null
+}
+
+/** Feeds every composable call into the applier inferencer so composition-target mismatches get reported. */
+object ComposableTargetChecker : FirFunctionCallChecker(MppCheckerKind.Common) {
+    @OptIn(SymbolInternals::class)
+    override fun check(expression: FirFunctionCall, context: CheckerContext, reporter: DiagnosticReporter) {
+        val calleeFunction = expression.calleeReference.toResolvedCallableSymbol() ?: return
+        if (!calleeFunction.isComposable(context)) return
+        updateParents(context)
+        val infer = getInfer(context, reporter)
+        val call = inferenceNodeOf(expression, context)
+        val target = callableInferenceNodeOf(expression, calleeFunction, context)
+        val parameters = calleeFunction.parameters(context)
+        val argumentsMapping = expression.resolvedArgumentMapping
+        // A composable parameter can be absent from the mapping (e.g. defaulted and omitted at the call),
+        // so look it up with firstNotNullOfOrNull rather than throwing NoSuchElementException.
+        // NOTE(review): mapNotNull drops such parameters, shifting later positions; confirm visitCall expects this.
+        val arguments = parameters.mapNotNull { parameter ->
+            argumentsMapping?.firstNotNullOfOrNull { (argument, mappedParameter) ->
+                if (mappedParameter == parameter.fir) inferenceNodeOf(argument, context) else null
+            }
+        }
+        infer.visitCall(call, target, arguments)
+    }
+}
+
+// Applier inferencer over FIR nodes; mismatches are reported through the current `reporter`.
+private class FirApplierInference(
+    val context: CheckerContext,
+    reporter: DiagnosticReporter,
+    // Mutable cell shared with the error reporter below: an object created in the supertype call can only
+    // capture constructor parameters, so reassigning a plain `var reporter` property would be ignored.
+    private val reporterCell: Array<DiagnosticReporter> = arrayOf(reporter),
+) : ApplierInferencer<InferenceNodeType, FirInferenceNode>(
+    typeAdapter = object : TypeAdapter<InferenceNodeType> {
+        override fun declaredSchemaOf(type: InferenceNodeType): Scheme = type.toScheme(context)
+        override fun currentInferredSchemeOf(type: InferenceNodeType): Scheme? = null
+        override fun updatedInferredScheme(type: InferenceNodeType, scheme: Scheme) {}
+    },
+    nodeAdapter = object : NodeAdapter<InferenceNodeType, FirInferenceNode> {
+        // The container is the nearest enclosing FirFunction; a node with none is its own container.
+        override fun containerOf(node: FirInferenceNode): FirInferenceNode {
+            var current = node.element.parent
+            while (current != null) {
+                if (current is FirFunction) return inferenceNodeOf(current, context)
+                current = current.parent
+            }
+            return node
+        }
+        override fun kindOf(node: FirInferenceNode): NodeKind = node.kind
+        override fun schemeParameterIndexOf(node: FirInferenceNode, container: FirInferenceNode): Int = node.parameterIndex
+        override fun typeOf(node: FirInferenceNode): InferenceNodeType? = node.type
+        override fun referencedContainerOf(node: FirInferenceNode): FirInferenceNode? = node.referenceContainer
+    },
+    errorReporter = object : ErrorReporter<FirInferenceNode> {
+        private fun descriptionFrom(token: String): String = token // TODO: find the message if appropriate
+        override fun reportCallError(node: FirInferenceNode, expected: String, received: String) {
+            if (expected != received) {
+                reporterCell[0].reportOn(
+                    source = node.element.source,
+                    factory = ComposeErrors.COMPOSE_APPLIER_CALL_MISMATCH,
+                    context = context,
+                    a = descriptionFrom(expected),
+                    b = descriptionFrom(received)
+                )
+            }
+        }
+
+        override fun reportParameterError(node: FirInferenceNode, index: Int, expected: String, received: String) {
+            reporterCell[0].reportOn(
+                source = node.element.source,
+                factory = ComposeErrors.COMPOSE_APPLIER_PARAMETER_MISMATCH,
+                context = context,
+                a = expected,
+                b = received
+            )
+        }
+
+        override fun log(node: FirInferenceNode?, message: String) {}
+    },
+    lazySchemeStorage = object : LazySchemeStorage<FirInferenceNode> {
+        override fun getLazyScheme(node: FirInferenceNode): LazyScheme? = lazySchemes[node]
+
+        override fun storeLazyScheme(node: FirInferenceNode, value: LazyScheme) {
+            lazySchemes[node] = value
+        }
+    }
+) {
+    /** Reporter used for diagnostics; reassigning it redirects errors reported afterwards. */
+    var reporter: DiagnosticReporter
+        get() = reporterCell[0]
+        set(value) { reporterCell[0] = value }
+}
+
+// Lazy scheme per node. NOTE(review): both maps here are file-level and never cleared, so they only grow.
+private val lazySchemes = mutableMapOf<FirInferenceNode, LazyScheme>()
+/**
+ * Elements that, for inference, must be treated as identical: an anonymous function is mapped to the
+ * lambda expression or SAM conversion that wraps it.
+ */
+private val mapping = mutableMapOf<FirElement, FirElement>()
+
+private fun inferenceNodeOf(element: FirElement, context: CheckerContext): FirInferenceNode {
+    element.firTrace[WritableSlices.NODE, element]?.let { cached -> return cached }
+    val node = when (element) {
+        is FirAnonymousFunctionExpression -> {
+            // Alias the anonymous function to its lambda expression so both resolve to one node.
+            mapping[element.anonymousFunction] = element
+            FirLambdaInferenceNode(element)
+        }
+        is FirSamConversionExpression -> {
+            (element.expression as? FirAnonymousFunctionExpression)?.let { mapping[it.anonymousFunction] = element }
+            FirSamInferenceNode(context, element)
+        }
+        is FirAnonymousFunction -> callableInferenceNodeOf(element, element.symbol, context)
+        is FirFunction -> FirFunctionInferenceNode(element)
+        else -> parameterInferenceNodeOrNull(element, context) ?: FirElementInferenceNode(element)
+    }
+    element.firTrace.record(WritableSlices.NODE, element, node)
+    return node
+}
+
+// Node for a call whose explicit receiver is a composable value parameter (e.g. `cContent.call()`),
+// or null when `expression` is not such a call.
+@OptIn(SymbolInternals::class)
+private fun parameterInferenceNodeOrNull(expression: FirElement, context: CheckerContext): FirInferenceNode? {
+    val call = expression as? FirFunctionCall ?: return null
+    val receiver = call.explicitReceiver as? FirQualifiedAccessExpression ?: return null
+    val parameterSymbol = receiver.toResolvedCallableSymbol() as? FirValueParameterSymbol ?: return null
+    val function = parameterSymbol.containingFunctionSymbol
+    // The index counts only composable parameters, matching the parameter list of the function's scheme.
+    val index = function.valueParameterSymbols.filter { it.isComposable(context) }.indexOf(parameterSymbol)
+    if (index < 0) return null
+    // The reference is resolved against the inference node of the declaring function.
+    return FirParameterReferenceNode(expression, index, inferenceNodeOf(function.fir, context))
+}
+
+private fun getInfer(context: CheckerContext, reporter: DiagnosticReporter): FirApplierInference {
+    // NOTE(review): WritableSlices.INFER is never recorded, so a fresh inferencer is built on each call.
+    val infer = context.firTrace[WritableSlices.INFER, context] ?: FirApplierInference(context, reporter)
+    return infer.also { it.reporter = reporter }
+}
+
+private fun updateParents(context: CheckerContext) {
+    val elements = context.containingElements
+    // Walk from the innermost element outwards, stopping at the first one that is already linked.
+    for (index in elements.lastIndex downTo 1) {
+        val child = elements[index]
+        if (child.parent != null) break
+        child.firTrace.record(WritableSlices.PARENT, child, elements[index - 1])
+    }
+}
+
+private val FirElement.parent: FirElement? get() = firTrace[WritableSlices.PARENT, this] // recorded by updateParents
+
+private object WritableSlices {
+    // Inferencer cached per CheckerContext (read by getInfer).
+    val INFER: WritableSlice<CheckerContext, FirApplierInference> =
+        BasicWritableSlice(RewritePolicy.DO_NOTHING)
+    // Inference node cached per FIR element (see inferenceNodeOf).
+    val NODE: WritableSlice<FirElement, FirInferenceNode> =
+        BasicWritableSlice(RewritePolicy.DO_NOTHING)
+    // Parent links recorded by updateParents from CheckerContext.containingElements.
+    val PARENT: WritableSlice<FirElement, FirElement?> = BasicWritableSlice(RewritePolicy.DO_NOTHING)
+}
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt
index af02ccd..7088454 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrorMessages.kt
@@ -99,5 +99,24 @@
             ComposeErrors.MISMATCHED_COMPOSABLE_IN_EXPECT_ACTUAL,
             "Mismatched @Composable annotation between expect and actual declaration"
         )
+
+        map.put(
+            ComposeErrors.COMPOSE_APPLIER_CALL_MISMATCH,
+            "Calling a {1} composable function where a {0} composable was expected",
+            FirDiagnosticRenderers.TO_STRING,
+            FirDiagnosticRenderers.TO_STRING
+        )
+
+        map.put(
+            ComposeErrors.COMPOSE_APPLIER_PARAMETER_MISMATCH,
+            "A {1} composable parameter was provided where a {0} composable was expected",
+            FirDiagnosticRenderers.TO_STRING,
+            FirDiagnosticRenderers.TO_STRING
+        )
+
+        map.put(
+            ComposeErrors.COMPOSE_APPLIER_DECLARATION_MISMATCH,
+            "The composition target of an override must match the ancestor target"
+        )
     }
 }
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt
index 2c2ceaa..2e1c454 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeErrors.kt
@@ -20,17 +20,7 @@
 import com.intellij.openapi.util.TextRange
 import com.intellij.psi.PsiElement
 import com.intellij.util.diff.FlyweightCapableTreeStructure
-import org.jetbrains.kotlin.diagnostics.LightTreePositioningStrategies
-import org.jetbrains.kotlin.diagnostics.LightTreePositioningStrategy
-import org.jetbrains.kotlin.diagnostics.PositioningStrategies
-import org.jetbrains.kotlin.diagnostics.PositioningStrategy
-import org.jetbrains.kotlin.diagnostics.SourceElementPositioningStrategies
-import org.jetbrains.kotlin.diagnostics.SourceElementPositioningStrategy
-import org.jetbrains.kotlin.diagnostics.error0
-import org.jetbrains.kotlin.diagnostics.error2
-import org.jetbrains.kotlin.diagnostics.error3
-import org.jetbrains.kotlin.diagnostics.findChildByType
-import org.jetbrains.kotlin.diagnostics.markElement
+import org.jetbrains.kotlin.diagnostics.*
 import org.jetbrains.kotlin.diagnostics.rendering.RootDiagnosticRendererFactory
 import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
 import org.jetbrains.kotlin.fir.symbols.impl.FirValueParameterSymbol
@@ -51,7 +41,7 @@
     val NONREADONLY_CALL_IN_READONLY_COMPOSABLE by error0<PsiElement>()
 
     val CAPTURED_COMPOSABLE_INVOCATION by
-        error2<PsiElement, FirVariableSymbol<*>, FirCallableSymbol<*>>()
+        error2<PsiElement, FirVariableSymbol<*>, FirCallableSymbol<*>>()
 
     // composable calls are not allowed in try expressions
     // error goes on the `try` keyword
@@ -60,10 +50,10 @@
     )
 
     val MISSING_DISALLOW_COMPOSABLE_CALLS_ANNOTATION by error3<
-        PsiElement,
-        FirValueParameterSymbol, // unmarked
-        FirValueParameterSymbol, // marked
-        FirCallableSymbol<*>>()
+            PsiElement,
+            FirValueParameterSymbol, // unmarked
+            FirValueParameterSymbol, // marked
+            FirCallableSymbol<*>>()
 
     val ABSTRACT_COMPOSABLE_DEFAULT_PARAMETER_VALUE by error0<PsiElement>()
 
@@ -91,6 +81,14 @@
         SourceElementPositioningStrategies.DECLARATION_NAME
     )
 
+    val COMPOSE_APPLIER_CALL_MISMATCH by warning2<PsiElement, String, String>()
+
+    val COMPOSE_APPLIER_PARAMETER_MISMATCH by warning2<PsiElement, String, String>()
+
+    val COMPOSE_APPLIER_DECLARATION_MISMATCH by warning0<PsiElement>(
+        ComposeSourceElementPositioningStrategies.DECLARATION_NAME_OR_DEFAULT
+    )
+
     init {
         RootDiagnosticRendererFactory.registerFactory(ComposeErrorMessages)
     }
@@ -105,20 +103,20 @@
                 }
                 return PositioningStrategies.DEFAULT.mark(element)
             }
-    }
+        }
 
     private val LIGHT_TREE_TRY_KEYWORD: LightTreePositioningStrategy =
         object : LightTreePositioningStrategy() {
-        override fun mark(
-            node: LighterASTNode,
-            startOffset: Int,
-            endOffset: Int,
-            tree: FlyweightCapableTreeStructure<LighterASTNode>
-        ): List<TextRange> {
-            val target = tree.findChildByType(node, KtTokens.TRY_KEYWORD) ?: node
-            return markElement(target, startOffset, endOffset, tree, node)
+            override fun mark(
+                node: LighterASTNode,
+                startOffset: Int,
+                endOffset: Int,
+                tree: FlyweightCapableTreeStructure<LighterASTNode>,
+            ): List<TextRange> {
+                val target = tree.findChildByType(node, KtTokens.TRY_KEYWORD) ?: node
+                return markElement(target, startOffset, endOffset, tree, node)
+            }
         }
-    }
 
     private val PSI_DECLARATION_NAME_OR_DEFAULT: PositioningStrategy<PsiElement> =
         object : PositioningStrategy<PsiElement>() {
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt
index ee5a0f4..700454c 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/ComposeFirExtensions.kt
@@ -103,7 +103,7 @@
 
     override val expressionCheckers: ExpressionCheckers = object : ExpressionCheckers() {
         override val functionCallCheckers: Set<FirFunctionCallChecker> =
-            setOf(ComposableFunctionCallChecker)
+            setOf(ComposableFunctionCallChecker, ComposableTargetChecker)
 
         override val propertyAccessExpressionCheckers: Set<FirPropertyAccessExpressionChecker> =
             setOf(ComposablePropertyAccessExpressionChecker)
diff --git a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt
index 56d0cc0..c844504 100644
--- a/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt
+++ b/plugins/compose/compiler-hosted/src/main/java/androidx/compose/compiler/plugins/kotlin/k2/FirUtils.kt
@@ -17,39 +17,30 @@
 package androidx.compose.compiler.plugins.kotlin.k2
 
 import androidx.compose.compiler.plugins.kotlin.ComposeClassIds
+import org.jetbrains.kotlin.descriptors.Modality
 import org.jetbrains.kotlin.fir.FirAnnotationContainer
 import org.jetbrains.kotlin.fir.FirSession
 import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
 import org.jetbrains.kotlin.fir.analysis.checkers.getAnnotationStringParameter
 import org.jetbrains.kotlin.fir.analysis.checkers.unsubstitutedScope
 import org.jetbrains.kotlin.fir.containingClassLookupTag
-import org.jetbrains.kotlin.fir.declarations.FirFunction
-import org.jetbrains.kotlin.fir.declarations.FirPropertyAccessor
-import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
-import org.jetbrains.kotlin.fir.declarations.hasAnnotation
+import org.jetbrains.kotlin.fir.declarations.*
 import org.jetbrains.kotlin.fir.declarations.utils.isOverride
+import org.jetbrains.kotlin.fir.declarations.utils.modality
 import org.jetbrains.kotlin.fir.expressions.FirFunctionCall
 import org.jetbrains.kotlin.fir.expressions.FirReturnExpression
 import org.jetbrains.kotlin.fir.references.toResolvedCallableSymbol
+import org.jetbrains.kotlin.fir.resolve.toClassSymbol
 import org.jetbrains.kotlin.fir.resolve.toSymbol
+import org.jetbrains.kotlin.fir.scopes.collectAllFunctions
 import org.jetbrains.kotlin.fir.scopes.getDirectOverriddenFunctions
 import org.jetbrains.kotlin.fir.scopes.getDirectOverriddenProperties
+import org.jetbrains.kotlin.fir.scopes.unsubstitutedScope
 import org.jetbrains.kotlin.fir.symbols.FirBasedSymbol
 import org.jetbrains.kotlin.fir.symbols.SymbolInternals
-import org.jetbrains.kotlin.fir.symbols.impl.FirCallableSymbol
-import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
-import org.jetbrains.kotlin.fir.symbols.impl.FirFunctionSymbol
-import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
-import org.jetbrains.kotlin.fir.symbols.impl.FirPropertyAccessorSymbol
-import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
+import org.jetbrains.kotlin.fir.symbols.impl.*
 import org.jetbrains.kotlin.fir.symbols.lazyResolveToPhase
-import org.jetbrains.kotlin.fir.types.ConeKotlinType
-import org.jetbrains.kotlin.fir.types.ProjectionKind
-import org.jetbrains.kotlin.fir.types.coneType
-import org.jetbrains.kotlin.fir.types.isArrayType
-import org.jetbrains.kotlin.fir.types.isString
-import org.jetbrains.kotlin.fir.types.isUnit
-import org.jetbrains.kotlin.fir.types.type
+import org.jetbrains.kotlin.fir.types.*
 import org.jetbrains.kotlin.name.JvmStandardClassIds
 
 fun FirAnnotationContainer.hasComposableAnnotation(session: FirSession): Boolean =
@@ -67,17 +58,36 @@
 fun FirAnnotationContainer.hasDisallowComposableCallsAnnotation(session: FirSession): Boolean =
     hasAnnotation(ComposeClassIds.DisallowComposableCalls, session)
 
-fun FirCallableSymbol<*>.isComposable(session: FirSession): Boolean =
+fun FirAnnotationContainer.hasComposableTargetMarkerAnnotation(session: FirSession): Boolean =
+    hasAnnotation(ComposeClassIds.ComposableTargetMarker, session)
+
+fun FirCallableSymbol<*>.isComposable(context: CheckerContext): Boolean =
     when (this) {
         is FirFunctionSymbol<*> ->
-            hasComposableAnnotation(session)
+            hasComposableAnnotation(context.session)
         is FirPropertySymbol ->
             getterSymbol?.let {
-                it.hasComposableAnnotation(session) || it.isComposableDelegate(session)
+                it.hasComposableAnnotation(context.session) || it.isComposableDelegate(context)
             } ?: false
+        is FirValueParameterSymbol -> isComposableParameter(context)
         else -> false
     }
 
+/** Whether this parameter's type is `@Composable`, or has a single abstract function that is `@Composable`. */
+private fun FirValueParameterSymbol.isComposableParameter(context: CheckerContext): Boolean =
+    resolvedReturnType.customAnnotations.hasAnnotation(ComposeClassIds.Composable, context.session) ||
+            findSamFunction(context)?.isComposable(context) == true
+
+/** Returns the sole abstract function of this parameter's class type, or null if not a class type or not exactly one. */
+private fun FirValueParameterSymbol.findSamFunction(context: CheckerContext): FirNamedFunctionSymbol? {
+    val session = context.session
+    val classSymbol = resolvedReturnType.toClassSymbol(session) ?: return null
+    return classSymbol
+        .unsubstitutedScope(session, context.scopeSession, withForcedTypeCalculator = true, memberRequiredPhase = null)
+        .collectAllFunctions()
+        .singleOrNull { it.modality == Modality.ABSTRACT }
+}
+
 fun FirCallableSymbol<*>.isReadOnlyComposable(session: FirSession): Boolean =
     when (this) {
         is FirFunctionSymbol<*> ->
@@ -88,7 +98,7 @@
     }
 
 @OptIn(SymbolInternals::class)
-private fun FirPropertyAccessorSymbol.isComposableDelegate(session: FirSession): Boolean {
+private fun FirPropertyAccessorSymbol.isComposableDelegate(context: CheckerContext): Boolean {
     if (!propertySymbol.hasDelegate) return false
     fir.lazyResolveToPhase(FirResolvePhase.BODY_RESOLVE)
     return ((fir
@@ -98,7 +108,7 @@
         ?.result as? FirFunctionCall)
         ?.calleeReference
         ?.toResolvedCallableSymbol()
-        ?.isComposable(session)
+        ?.isComposable(context)
         ?: false
 }
 
@@ -173,5 +183,5 @@
 
 private val FirFunctionSymbol<*>.explicitParameterTypes: List<ConeKotlinType>
     get() = resolvedContextReceivers.map { it.typeRef.coneType } +
-        listOfNotNull(receiverParameter?.typeRef?.coneType) +
-        valueParameterSymbols.map { it.resolvedReturnType }
+            listOfNotNull(receiverParameter?.typeRef?.coneType) +
+            valueParameterSymbols.map { it.resolvedReturnType }