Partially implement invoke resolution
diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CallResolver.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CallResolver.kt index 2af5b46..6309e4f 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CallResolver.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CallResolver.kt
@@ -96,7 +96,8 @@ val implicitExtensionReceiverValue: ImplicitReceiverValue?, val explicitReceiverKind: ExplicitReceiverKind, private val inferenceComponents: InferenceComponents, - private val baseSystem: ConstraintStorage + private val baseSystem: ConstraintStorage, + val callInfo: CallInfo ) { val system by lazy { val system = inferenceComponents.createConstraintSystem() @@ -280,7 +281,8 @@ if (candidate.hasConsistentExtensionReceiver(extensionReceiver) && candidate.dispatchReceiverValue() == null) { processor.consumeCandidate( candidate as T, dispatchReceiverValue = null, - implicitExtensionReceiverValue = implicitExtensionReceiver) + implicitExtensionReceiverValue = implicitExtensionReceiver + ) } else { ProcessorAction.NEXT } @@ -349,7 +351,7 @@ towerScopeLevel: TowerScopeLevel, group: Int ): ProcessorAction { - if (checkSkip(group, resultCollector)) return ProcessorAction.NEXT + if (skipGroup(group, resultCollector)) return ProcessorAction.NEXT if (kind != TowerDataKind.EMPTY) return ProcessorAction.NEXT return QualifiedReceiverTowerLevel(session).processElementsByName( @@ -388,12 +390,20 @@ group: Int ): ProcessorAction + private var processed = -1 private var stopGroup = Int.MAX_VALUE - fun checkSkip(group: Int, resultCollector: CandidateCollector): Boolean { + fun skipGroup(group: Int, resultCollector: CandidateCollector): Boolean { if (resultCollector.isSuccess() && stopGroup == Int.MAX_VALUE) { stopGroup = group } - return group > stopGroup + if (group > processed) { + processed = group + } + if (group < processed) { + return true + } + if (group > stopGroup) return true + return false } } @@ -452,6 +462,17 @@ resultCollector: CandidateCollector, callResolver: CallResolver ): TowerDataConsumer { + val varCallInfo = CallInfo( + CallKind.VariableAccess, + callInfo.explicitReceiver, + emptyList(), + callInfo.isSafeCall, + callInfo.typeArguments, + inferenceComponents.session, + callInfo.containingFile, + callInfo.container, + callInfo.typeProvider + ) return PrioritizedTowerDataConsumer( resultCollector, createSimpleConsumer( @@ -462,29 +483,23 @@ inferenceComponents, resultCollector ), - createSimpleConsumer( - session, - name, - TowerScopeLevel.Token.Properties, - callInfo, - inferenceComponents, - InvokeCandidateCollector( - callResolver, - varCallInfo = CallInfo( - CallKind.VariableAccess, - callInfo.explicitReceiver, - emptyList(), - callInfo.isSafeCall, - callInfo.typeArguments, - inferenceComponents.session, - callInfo.containingFile, - callInfo.container, - callInfo.typeProvider - ), - invokeCallInfo = callInfo, - components = inferenceComponents + MultiplexerTowerDataConsumer(resultCollector).apply { + addConsumer( + createSimpleConsumer( + session, + name, + TowerScopeLevel.Token.Properties, + varCallInfo, + inferenceComponents, + InvokeCandidateCollector( + callResolver, + invokeCallInfo = callInfo, + components = inferenceComponents, + multiplexer = this + ) + ) ) - ) + } ) } @@ -534,7 +549,7 @@ towerScopeLevel: TowerScopeLevel, group: Int ): ProcessorAction { - if (checkSkip(group, resultCollector)) return ProcessorAction.NEXT + if (skipGroup(group, resultCollector)) return ProcessorAction.NEXT for ((index, consumer) in consumers.withIndex()) { val action = consumer.consume(kind, towerScopeLevel, group * consumers.size + index) if (action.stop()) { @@ -545,6 +560,52 @@ } } +class MultiplexerTowerDataConsumer( + val resultCollector: CandidateCollector +) : TowerDataConsumer() { + + val consumers = mutableListOf<TowerDataConsumer>() + val newConsumers = mutableListOf<TowerDataConsumer>() + + val kinds = mutableListOf<TowerDataKind>() + val groups = mutableListOf<Int>() + val levels = mutableListOf<TowerScopeLevel>() + + override fun consume( + kind: TowerDataKind, + towerScopeLevel: TowerScopeLevel, + group: Int + ): ProcessorAction { + if (skipGroup(group, resultCollector)) return ProcessorAction.NEXT + consumers += newConsumers + newConsumers.clear() + kinds += kind + groups += group + levels += towerScopeLevel + + for (consumer in consumers) { + val action = consumer.consume(kind, towerScopeLevel, group) + if (action.stop()) { + return ProcessorAction.STOP + } + } + return ProcessorAction.NEXT + } + + fun addConsumer(consumer: TowerDataConsumer): ProcessorAction = + run { + for (index in kinds.indices) { + if (consumer.consume(kinds[index], levels[index], groups[index]).stop()) { + return@run ProcessorAction.STOP + } + } + return@run ProcessorAction.NEXT + }.also { + newConsumers += consumer + } +} + + class ExplicitReceiverTowerDataConsumer<T : ConeSymbol>( val session: FirSession, val name: Name, @@ -560,7 +621,7 @@ towerScopeLevel: TowerScopeLevel, group: Int ): ProcessorAction { - if (checkSkip(group, resultCollector)) return ProcessorAction.NEXT + if (skipGroup(group, resultCollector)) return ProcessorAction.NEXT return when (kind) { TowerDataKind.EMPTY -> MemberScopeTowerLevel(session, explicitReceiver).processElementsByName( @@ -631,7 +692,7 @@ towerScopeLevel: TowerScopeLevel, group: Int ): ProcessorAction { - if (checkSkip(group, resultCollector)) return ProcessorAction.NEXT + if (skipGroup(group, resultCollector)) return ProcessorAction.NEXT return when (kind) { TowerDataKind.TOWER_LEVEL -> { @@ -703,20 +764,14 @@ return group } - val collector = CandidateCollector(callInfo!!, components) - private lateinit var towerDataConsumer: TowerDataConsumer + val collector by lazy { CandidateCollector(components) } + lateinit var towerDataConsumer: TowerDataConsumer private lateinit var implicitReceiverValues: List<ImplicitReceiverValue> - fun runTowerResolver(towerDataConsumer: TowerDataConsumer, implicitReceiverValues: List<ImplicitReceiverValue>): CandidateCollector { - this.towerDataConsumer = towerDataConsumer + fun runTowerResolver(consumer: TowerDataConsumer, implicitReceiverValues: List<ImplicitReceiverValue>): CandidateCollector { this.implicitReceiverValues = implicitReceiverValues + towerDataConsumer = consumer - runTowerResolver() - - return collector - } - - fun runTowerResolver(towerDataConsumer: TowerDataConsumer = this.towerDataConsumer) { var group = 0 towerDataConsumer.consume(TowerDataKind.EMPTY, TowerScopeLevel.Empty, group++) @@ -739,8 +794,9 @@ } - } + return collector + } } @@ -755,9 +811,7 @@ } -var ID = "" - -open class CandidateCollector(val callInfo: CallInfo, val components: InferenceComponents) { +open class CandidateCollector(val components: InferenceComponents) { val groupNumbers = mutableListOf<Int>() val candidates = mutableListOf<Candidate>() @@ -780,8 +834,8 @@ val sink = CheckerSinkImpl(components) var finished = false sink.continuation = suspend { - for (stage in callInfo.callKind.sequence()) { - stage.check(candidate, sink, callInfo) + for (stage in candidate.callInfo.callKind.sequence()) { + stage.check(candidate, sink, candidate.callInfo) } }.createCoroutineUnintercepted(completion = object : Continuation<Unit> { override val context: CoroutineContext @@ -847,8 +901,11 @@ } class InvokeCandidateCollector( - val callResolver: CallResolver, val varCallInfo: CallInfo, val invokeCallInfo: CallInfo, components: InferenceComponents -) : CandidateCollector(varCallInfo, components) { + val callResolver: CallResolver, + val invokeCallInfo: CallInfo, + components: InferenceComponents, + val multiplexer: MultiplexerTowerDataConsumer +) : CandidateCollector(components) { override fun consumeCandidate(group: Int, candidate: Candidate): CandidateApplicability { val applicability = super.consumeCandidate(group, candidate) @@ -858,7 +915,13 @@ val boundInvokeCallInfo = CallInfo( invokeCallInfo.callKind, FirQualifiedAccessExpressionImpl(session, null, false).apply { - calleeReference = FirNamedReferenceWithCandidate(session, null, (candidate.symbol as ConeCallableSymbol).callableId.callableName, candidate) + calleeReference = FirNamedReferenceWithCandidate( + session, + null, + (candidate.symbol as ConeCallableSymbol).callableId.callableName, + candidate + ) + typeRef = callResolver.typeCalculator.tryCalculateReturnType(candidate.symbol.firUnsafe()) }, invokeCallInfo.arguments, invokeCallInfo.isSafeCall, @@ -868,9 +931,10 @@ invokeCallInfo.container, invokeCallInfo.typeProvider ) - val invokeConsumer = createSimpleFunctionConsumer(session, Name.identifier("invoke"), boundInvokeCallInfo, components, callResolver.collector) + val invokeConsumer = + createSimpleFunctionConsumer(session, Name.identifier("invoke"), boundInvokeCallInfo, components, callResolver.collector) - callResolver.runTowerResolver(invokeConsumer) + multiplexer.addConsumer(invokeConsumer) } return applicability
diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CandidateFactory.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CandidateFactory.kt index 4929f89b0..fc76643 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CandidateFactory.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/calls/CandidateFactory.kt
@@ -15,7 +15,7 @@ class CandidateFactory( val inferenceComponents: InferenceComponents, - callInfo: CallInfo + val callInfo: CallInfo ) { val baseSystem: ConstraintStorage @@ -37,7 +37,7 @@ ): Candidate { return Candidate( symbol, dispatchReceiverValue, implicitExtensionReceiverValue, - explicitReceiverKind, inferenceComponents, baseSystem + explicitReceiverKind, inferenceComponents, baseSystem, callInfo ) } }
diff --git a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirBodyResolveTransformer.kt b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirBodyResolveTransformer.kt index c584450..e744342 100644 --- a/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirBodyResolveTransformer.kt +++ b/compiler/fir/resolve/src/org/jetbrains/kotlin/fir/resolve/transformers/FirBodyResolveTransformer.kt
@@ -348,7 +348,8 @@ val consumer = createVariableAndObjectConsumer( session, callee.name, - info, inferenceComponents + info, inferenceComponents, + resolver.collector ) val result = resolver.runTowerResolver(consumer, implicitReceiverStack.asReversed()) @@ -551,7 +552,7 @@ resolver.callInfo = info resolver.scopes = (scopes + localScopes).asReversed() - val consumer = createFunctionConsumer(session, name, info, inferenceComponents) + val consumer = createFunctionConsumer(session, name, info, inferenceComponents, resolver.collector, resolver) val result = resolver.runTowerResolver(consumer, implicitReceiverStack.asReversed()) val bestCandidates = result.bestCandidates() val reducedCandidates = if (result.currentApplicability < CandidateApplicability.SYNTHETIC_RESOLVED) {
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver.txt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver.txt index 945af23..4ee5801 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver.txt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver.txt
@@ -13,8 +13,8 @@ ^invoke this# } - public final fun bar(): R|kotlin/Unit| { - ^bar R|/x|() + public final fun bar(): R|Foo| { + ^bar R|/Foo.invoke|() } }
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver2.txt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver2.txt index 796a513..78d6d9d 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver2.txt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/explicitReceiver2.txt
@@ -23,8 +23,8 @@ public final val x: R|Bar| = R|/Bar.Bar|() public get(): R|Bar| - public final fun bar(): R|kotlin/Unit| { - ^bar R|/x|() + public final fun bar(): R|Foo| { + ^bar R|/Bar.invoke|() } }
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/extension.kt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/extension.kt index 2c22274..ebe97fb 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/extension.kt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/extension.kt
@@ -7,5 +7,5 @@ val x = 0 - fun foo() = x() + fun foo() = x() // should resolve to invoke } \ No newline at end of file
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/farInvokeExtension.kt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/farInvokeExtension.kt index eaf10f0..3b46415 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/farInvokeExtension.kt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/farInvokeExtension.kt
@@ -8,5 +8,5 @@ val x = 0 - fun foo() = x() + fun foo() = x() // should resolve to fun x } \ No newline at end of file
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.kt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.kt index 47deb73..81e849c 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.kt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.kt
@@ -1,6 +1,6 @@ class A { - fun bar() = foo() + fun bar() = foo() // should resolve to invoke fun invoke() = this }
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.txt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.txt index e470dc3..85af0fc 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.txt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/implicitTypeOrder.txt
@@ -4,8 +4,8 @@ super<R|kotlin/Any|>() } - public final fun bar(): <ERROR TYPE REF: Unresolved name: foo> { - ^bar <Unresolved name: foo>#() + public final fun bar(): R|A| { + ^bar R|/A.invoke|() } public final fun invoke(): R|A| {
diff --git a/compiler/fir/resolve/testData/resolve/expresssions/invoke/simple.txt b/compiler/fir/resolve/testData/resolve/expresssions/invoke/simple.txt index c12629e..06667b5 100644 --- a/compiler/fir/resolve/testData/resolve/expresssions/invoke/simple.txt +++ b/compiler/fir/resolve/testData/resolve/expresssions/invoke/simple.txt
@@ -10,5 +10,5 @@ } public final fun test(s: R|Simple|): R|kotlin/Unit| { - lval result: <ERROR TYPE REF: Unresolved name: s> = <Unresolved name: s>#() + lval result: R|kotlin/String| = R|/Simple.invoke|() }