[LL API] Add backend-computed code fragment mappings
diff --git a/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt b/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt index 6cf5155..84211a8 100644 --- a/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt +++ b/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt
@@ -14,10 +14,7 @@ import org.jetbrains.kotlin.analysis.api.fir.KtFirAnalysisSession import org.jetbrains.kotlin.analysis.api.impl.base.util.KtCompiledFileForOutputFile import org.jetbrains.kotlin.analysis.low.level.api.fir.LLFirInternals -import org.jetbrains.kotlin.analysis.low.level.api.fir.api.DiagnosticCheckerFilter -import org.jetbrains.kotlin.analysis.low.level.api.fir.api.LLFirResolveSession -import org.jetbrains.kotlin.analysis.low.level.api.fir.api.collectDiagnosticsForFile -import org.jetbrains.kotlin.analysis.low.level.api.fir.api.getOrBuildFirFile +import org.jetbrains.kotlin.analysis.low.level.api.fir.api.* import org.jetbrains.kotlin.analysis.low.level.api.fir.api.targets.LLFirWholeFileResolveTarget import org.jetbrains.kotlin.analysis.low.level.api.fir.api.targets.resolve import org.jetbrains.kotlin.analysis.low.level.api.fir.compile.CodeFragmentCapturedValueAnalyzer @@ -42,8 +39,7 @@ import org.jetbrains.kotlin.fir.backend.* import org.jetbrains.kotlin.fir.backend.jvm.* import org.jetbrains.kotlin.fir.declarations.FirDeclaration -import org.jetbrains.kotlin.fir.declarations.FirFile -import org.jetbrains.kotlin.fir.declarations.FirResolvePhase +import org.jetbrains.kotlin.fir.declarations.* import org.jetbrains.kotlin.fir.pipeline.applyIrGenerationExtensions import org.jetbrains.kotlin.fir.pipeline.signatureComposerForJvmFir2Ir import org.jetbrains.kotlin.fir.psi @@ -57,8 +53,15 @@ import org.jetbrains.kotlin.ir.backend.jvm.serialization.JvmIrMangler import org.jetbrains.kotlin.ir.declarations.* import org.jetbrains.kotlin.ir.declarations.impl.IrFactoryImpl +import org.jetbrains.kotlin.ir.descriptors.IrBasedDeclarationDescriptor +import org.jetbrains.kotlin.ir.descriptors.IrBasedReceiverParameterDescriptor +import org.jetbrains.kotlin.ir.descriptors.IrBasedValueParameterDescriptor +import org.jetbrains.kotlin.ir.descriptors.IrBasedVariableDescriptor import org.jetbrains.kotlin.ir.expressions.* +import org.jetbrains.kotlin.ir.symbols.IrClassSymbol 
import org.jetbrains.kotlin.ir.symbols.IrSymbol +import org.jetbrains.kotlin.ir.types.IrSimpleType +import org.jetbrains.kotlin.ir.util.classId import org.jetbrains.kotlin.ir.util.StubGeneratorExtensions import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid @@ -71,6 +74,7 @@ import org.jetbrains.kotlin.psi2ir.generators.fragments.EvaluatorFragmentInfo import org.jetbrains.kotlin.resolve.source.PsiSourceFile import org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource +import org.jetbrains.kotlin.utils.addIfNotNull import org.jetbrains.kotlin.utils.addToStdlib.runIf import java.util.Collections @@ -201,7 +205,16 @@ } val outputFiles = generationState.factory.asList().map(::KtCompiledFileForOutputFile) - val capturedValues = codeFragmentMappings?.capturedValues ?: emptyList() + val capturedValues = buildList { + if (codeFragmentMappings != null) { + addAll(codeFragmentMappings.capturedValues) + } + for ((_, _, descriptor) in generationState.newFragmentCaptureParameters) { + if (descriptor is IrBasedDeclarationDescriptor<*>) { + addIfNotNull(computeAdditionalCodeFragmentMapping(descriptor)) + } + } + } return KtCompilationResult.Success(outputFiles, capturedValues) } finally { generationState.destroy() @@ -222,6 +235,56 @@ val patchingVisitor = IrDeclarationPatchingVisitor(collectingVisitor.mappings) irCodeFragmentFiles.forEach { it.acceptVoid(patchingVisitor) } }
+
+    /**
+     * Converts a capture parameter reported by the backend (`generationState.newFragmentCaptureParameters`)
+     * into a [CodeFragmentCapturedValue]. Every value produced here has `isCrossingInlineBounds = true`.
+     *
+     * Returns `null` when [descriptor] wraps something other than a receiver, a local variable or a value
+     * parameter, or when a context/dispatch receiver's type has no class to derive a class id from.
+     */
+    private fun computeAdditionalCodeFragmentMapping(descriptor: IrBasedDeclarationDescriptor<*>): CodeFragmentCapturedValue? {
+        val owner = descriptor.owner
+
+        if (descriptor is IrBasedReceiverParameterDescriptor && owner is IrValueParameter) {
+            val receiverClass = (owner.type as? IrSimpleType)?.classifier as? IrClassSymbol
+            val receiverClassId = receiverClass?.owner?.classId
+
+            // A receiver parameter with a non-negative index is treated as a context receiver, labelled after its class.
+            if (owner.index >= 0) {
+                val labelName = receiverClassId?.shortClassName ?: return null
+                return CodeFragmentCapturedValue.ContextReceiver(owner.index, labelName, isCrossingInlineBounds = true)
+            }
+
+            val parent = owner.parent as? IrFunction ?: return null
+            if (parent.dispatchReceiverParameter == owner) {
+                return receiverClassId?.let { CodeFragmentCapturedValue.ContainingClass(it, isCrossingInlineBounds = true) }
+            }
+
+            // The extension receiver is labelled after the function, so it must not depend on `receiverClassId`;
+            // otherwise receivers without a class classifier (e.g. `T` in `fun <T> T.foo()`) would be dropped.
+            return CodeFragmentCapturedValue.ExtensionReceiver(parent.name.asString(), isCrossingInlineBounds = true)
+        }
+
+        if (descriptor is IrBasedVariableDescriptor && owner is IrVariable) {
+            val name = owner.name
+            val isMutated = false // TODO capture the usage somehow
+
+            if (owner.origin == IrDeclarationOrigin.PROPERTY_DELEGATE) {
+                return CodeFragmentCapturedValue.LocalDelegate(name, isMutated, isCrossingInlineBounds = true)
+            }
+
+            return CodeFragmentCapturedValue.Local(name, isMutated, isCrossingInlineBounds = true)
+        }
+
+        if (descriptor is IrBasedValueParameterDescriptor && owner is IrValueParameter) {
+            val name = owner.name
+            return CodeFragmentCapturedValue.Local(name, isMutated = false, isCrossingInlineBounds = true)
+        }
+
+        return null
+    }
+
     private fun getFullyResolvedFirFile(file: KtFile): FirFile {
         val firFile = file.getOrBuildFirFile(firResolveSession)
         LLFirWholeFileResolveTarget(firFile).resolve(FirResolvePhase.BODY_RESOLVE)