[LL API] Add backend-computed code fragment mappings
diff --git a/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt b/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt
index 6cf5155..84211a8 100644
--- a/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt
+++ b/analysis/analysis-api-fir/src/org/jetbrains/kotlin/analysis/api/fir/components/KtFirCompilerFacility.kt
@@ -14,10 +14,7 @@
import org.jetbrains.kotlin.analysis.api.fir.KtFirAnalysisSession
import org.jetbrains.kotlin.analysis.api.impl.base.util.KtCompiledFileForOutputFile
import org.jetbrains.kotlin.analysis.low.level.api.fir.LLFirInternals
-import org.jetbrains.kotlin.analysis.low.level.api.fir.api.DiagnosticCheckerFilter
-import org.jetbrains.kotlin.analysis.low.level.api.fir.api.LLFirResolveSession
-import org.jetbrains.kotlin.analysis.low.level.api.fir.api.collectDiagnosticsForFile
-import org.jetbrains.kotlin.analysis.low.level.api.fir.api.getOrBuildFirFile
+import org.jetbrains.kotlin.analysis.low.level.api.fir.api.*
import org.jetbrains.kotlin.analysis.low.level.api.fir.api.targets.LLFirWholeFileResolveTarget
import org.jetbrains.kotlin.analysis.low.level.api.fir.api.targets.resolve
import org.jetbrains.kotlin.analysis.low.level.api.fir.compile.CodeFragmentCapturedValueAnalyzer
@@ -42,8 +39,7 @@
import org.jetbrains.kotlin.fir.backend.*
import org.jetbrains.kotlin.fir.backend.jvm.*
import org.jetbrains.kotlin.fir.declarations.FirDeclaration
-import org.jetbrains.kotlin.fir.declarations.FirFile
-import org.jetbrains.kotlin.fir.declarations.FirResolvePhase
+import org.jetbrains.kotlin.fir.declarations.*
import org.jetbrains.kotlin.fir.pipeline.applyIrGenerationExtensions
import org.jetbrains.kotlin.fir.pipeline.signatureComposerForJvmFir2Ir
import org.jetbrains.kotlin.fir.psi
@@ -57,8 +53,15 @@
import org.jetbrains.kotlin.ir.backend.jvm.serialization.JvmIrMangler
import org.jetbrains.kotlin.ir.declarations.*
import org.jetbrains.kotlin.ir.declarations.impl.IrFactoryImpl
+import org.jetbrains.kotlin.ir.descriptors.IrBasedDeclarationDescriptor
+import org.jetbrains.kotlin.ir.descriptors.IrBasedReceiverParameterDescriptor
+import org.jetbrains.kotlin.ir.descriptors.IrBasedValueParameterDescriptor
+import org.jetbrains.kotlin.ir.descriptors.IrBasedVariableDescriptor
import org.jetbrains.kotlin.ir.expressions.*
+import org.jetbrains.kotlin.ir.symbols.IrClassSymbol
import org.jetbrains.kotlin.ir.symbols.IrSymbol
+import org.jetbrains.kotlin.ir.types.IrSimpleType
+import org.jetbrains.kotlin.ir.util.classId
import org.jetbrains.kotlin.ir.util.StubGeneratorExtensions
import org.jetbrains.kotlin.ir.visitors.IrElementVisitorVoid
import org.jetbrains.kotlin.ir.visitors.acceptChildrenVoid
@@ -71,6 +74,7 @@
import org.jetbrains.kotlin.psi2ir.generators.fragments.EvaluatorFragmentInfo
import org.jetbrains.kotlin.resolve.source.PsiSourceFile
import org.jetbrains.kotlin.serialization.deserialization.descriptors.DeserializedContainerSource
+import org.jetbrains.kotlin.utils.addIfNotNull
import org.jetbrains.kotlin.utils.addToStdlib.runIf
import java.util.Collections
@@ -201,7 +205,16 @@
}
val outputFiles = generationState.factory.asList().map(::KtCompiledFileForOutputFile)
- val capturedValues = codeFragmentMappings?.capturedValues ?: emptyList()
+ val capturedValues = buildList {
+ if (codeFragmentMappings != null) {
+ addAll(codeFragmentMappings.capturedValues)
+ }
+ for ((_, _, descriptor) in generationState.newFragmentCaptureParameters) {
+ if (descriptor is IrBasedDeclarationDescriptor<*>) {
+ addIfNotNull(computeAdditionalCodeFragmentMapping(descriptor))
+ }
+ }
+ }
return KtCompilationResult.Success(outputFiles, capturedValues)
} finally {
generationState.destroy()
@@ -222,6 +235,50 @@
val patchingVisitor = IrDeclarationPatchingVisitor(collectingVisitor.mappings)
irCodeFragmentFiles.forEach { it.acceptVoid(patchingVisitor) }
}
+
+    /**
+     * Converts a captured-value descriptor produced by the backend for a code fragment into a
+     * [CodeFragmentCapturedValue]; every value produced here is marked as crossing inline bounds.
+     *
+     * Returns `null` for unsupported descriptor kinds or when no mapping can be derived.
+     */
+    private fun computeAdditionalCodeFragmentMapping(descriptor: IrBasedDeclarationDescriptor<*>): CodeFragmentCapturedValue? {
+        val owner = descriptor.owner
+        return when {
+            descriptor is IrBasedReceiverParameterDescriptor && owner is IrValueParameter -> computeReceiverMapping(owner)
+            descriptor is IrBasedVariableDescriptor && owner is IrVariable -> {
+                val isMutated = false // TODO capture the usage somehow
+                if (owner.origin == IrDeclarationOrigin.PROPERTY_DELEGATE) {
+                    CodeFragmentCapturedValue.LocalDelegate(owner.name, isMutated, isCrossingInlineBounds = true)
+                } else {
+                    CodeFragmentCapturedValue.Local(owner.name, isMutated, isCrossingInlineBounds = true)
+                }
+            }
+            descriptor is IrBasedValueParameterDescriptor && owner is IrValueParameter ->
+                CodeFragmentCapturedValue.Local(owner.name, isMutated = false, isCrossingInlineBounds = true)
+            else -> null
+        }
+    }
+
+    /**
+     * Maps a receiver parameter typed with a class: a non-negative [IrValueParameter.index] yields a context receiver,
+     * the owning function's dispatch receiver yields the containing class, and any other receiver of a function
+     * is treated as its extension receiver. Returns `null` if the class id or the owning function is unavailable.
+     */
+    private fun computeReceiverMapping(receiver: IrValueParameter): CodeFragmentCapturedValue? {
+        val receiverClass = (receiver.type as? IrSimpleType)?.classifier as? IrClassSymbol
+        val receiverClassId = receiverClass?.owner?.classId ?: return null
+        if (receiver.index >= 0) {
+            return CodeFragmentCapturedValue.ContextReceiver(receiver.index, receiverClassId.shortClassName, isCrossingInlineBounds = true)
+        }
+        val function = receiver.parent as? IrFunction ?: return null
+        return if (function.dispatchReceiverParameter == receiver) {
+            CodeFragmentCapturedValue.ContainingClass(receiverClassId, isCrossingInlineBounds = true)
+        } else {
+            CodeFragmentCapturedValue.ExtensionReceiver(function.name.asString(), isCrossingInlineBounds = true)
+        }
+    }
+
private fun getFullyResolvedFirFile(file: KtFile): FirFile {
val firFile = file.getOrBuildFirFile(firResolveSession)
LLFirWholeFileResolveTarget(firFile).resolve(FirResolvePhase.BODY_RESOLVE)