[LL FIR, FIR] use correct KtSourceElement implementation in KtLightSourceElement.unwrapToKtPsiSourceElement

To be able to use FirPsiSourceElementFactory, also move the function closer to its single client, into LL FIR

^KT-57589
diff --git a/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/FileStructureElementDiagnosticsCollector.kt b/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/FileStructureElementDiagnosticsCollector.kt
index ae54ed2..a21d0da 100644
--- a/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/FileStructureElementDiagnosticsCollector.kt
+++ b/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/FileStructureElementDiagnosticsCollector.kt
@@ -8,6 +8,7 @@
 import org.jetbrains.kotlin.fir.analysis.collectors.CheckerRunningDiagnosticCollectorVisitor
 import org.jetbrains.kotlin.fir.declarations.FirDeclaration
 import org.jetbrains.kotlin.analysis.low.level.api.fir.diagnostics.fir.LLFirStructureElementDiagnosticsCollector
+import org.jetbrains.kotlin.analysis.low.level.api.fir.sessions.llFirSession
 import org.jetbrains.kotlin.fir.analysis.collectors.DiagnosticCollectorComponents
 
 internal class FileStructureElementDiagnosticsCollector private constructor(private val useExtendedCheckers: Boolean) {
@@ -20,7 +21,7 @@
         firDeclaration: FirDeclaration,
         createVisitor: (components: DiagnosticCollectorComponents) -> CheckerRunningDiagnosticCollectorVisitor,
     ): FileStructureElementDiagnosticList {
-        val reporter = LLFirDiagnosticReporter()
+        val reporter = LLFirDiagnosticReporter(firDeclaration.llFirSession)
         val collector = LLFirStructureElementDiagnosticsCollector(
             firDeclaration.moduleData.session,
             createVisitor,
diff --git a/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/LLFirDiagnosticReporter.kt b/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/LLFirDiagnosticReporter.kt
index 73c3938..733e9be 100644
--- a/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/LLFirDiagnosticReporter.kt
+++ b/analysis/low-level-api-fir/src/org/jetbrains/kotlin/analysis/low/level/api/fir/diagnostics/LLFirDiagnosticReporter.kt
@@ -7,10 +7,15 @@
 
 import com.intellij.psi.PsiElement
 import org.jetbrains.kotlin.AbstractKtSourceElement
+import org.jetbrains.kotlin.KtLightSourceElement
+import org.jetbrains.kotlin.KtPsiSourceElement
+import org.jetbrains.kotlin.WrappedTreeStructure
 import org.jetbrains.kotlin.analysis.low.level.api.fir.util.addValueFor
 import org.jetbrains.kotlin.diagnostics.*
+import org.jetbrains.kotlin.fir.FirSession
+import org.jetbrains.kotlin.fir.builder.toKtPsiSourceElement
 
-internal class LLFirDiagnosticReporter : DiagnosticReporter() {
+internal class LLFirDiagnosticReporter(private val session: FirSession) : DiagnosticReporter() {
     private val pendingDiagnostics = mutableMapOf<PsiElement, MutableList<KtPsiDiagnostic>>()
     val committedDiagnostics = mutableMapOf<PsiElement, MutableList<KtPsiDiagnostic>>()
 
@@ -20,7 +25,7 @@
 
         val psiDiagnostic = when (diagnostic) {
             is KtPsiDiagnostic -> diagnostic
-            is KtLightDiagnostic -> diagnostic.toPsiDiagnostic()
+            is KtLightDiagnostic -> diagnostic.toPsiDiagnostic(session)
             else -> error("Unknown diagnostic type ${diagnostic::class.simpleName}")
         }
         pendingDiagnostics.addValueFor(psiDiagnostic.psiElement, psiDiagnostic)
@@ -51,8 +56,8 @@
     }
 }
 
-private fun KtLightDiagnostic.toPsiDiagnostic(): KtPsiDiagnostic {
-    val psiSourceElement = element.unwrapToKtPsiSourceElement()
+private fun KtLightDiagnostic.toPsiDiagnostic(session: FirSession): KtPsiDiagnostic {
+    val psiSourceElement = element.unwrapToKtPsiSourceElement(session)
         ?: error("Diagnostic should be created from PSI in IDE")
     @Suppress("UNCHECKED_CAST")
     return when (this) {
@@ -97,3 +102,17 @@
         else -> error("Unknown diagnostic type ${this::class.simpleName}")
     }
 }
+
+/**
+ * We can create a [KtLightSourceElement] from a [KtPsiSourceElement] by using [KtPsiSourceElement.lighterASTNode];
+ * [unwrapToKtPsiSourceElement] allows getting the original [KtPsiSourceElement] back in such a case.
+ *
+ * If it is a `pure` [KtLightSourceElement], i.e., the compiler created it in light tree mode, then [unwrapToKtPsiSourceElement] returns `null`.
+ * Otherwise, it returns a [KtPsiSourceElement] created for the unwrapped node's PSI via [toKtPsiSourceElement] (or `null` if that node has no PSI).
+ */
+private fun KtLightSourceElement.unwrapToKtPsiSourceElement(session: FirSession): KtPsiSourceElement? {
+    val treeStructure = treeStructure // local copy so the type check below smart-casts it for the `unwrap` call
+    if (treeStructure !is WrappedTreeStructure) return null // not a wrapped tree: nothing to unwrap
+    val node = treeStructure.unwrap(lighterASTNode)
+    return node.psi?.toKtPsiSourceElement(session, kind) // the resulting element keeps this element's `kind`
+}
diff --git a/compiler/frontend.common/src/org/jetbrains/kotlin/KtSourceElement.kt b/compiler/frontend.common/src/org/jetbrains/kotlin/KtSourceElement.kt
index c92d116..4bc08ba 100644
--- a/compiler/frontend.common/src/org/jetbrains/kotlin/KtSourceElement.kt
+++ b/compiler/frontend.common/src/org/jetbrains/kotlin/KtSourceElement.kt
@@ -327,19 +327,6 @@
     override val elementType: IElementType
         get() = lighterASTNode.tokenType
 
-    /**
-     * We can create a [KtLightSourceElement] from a [KtPsiSourceElement] by using [KtPsiSourceElement.lighterASTNode];
-     * [unwrapToKtPsiSourceElement] allows to get original [KtPsiSourceElement] in such case.
-     *
-     * If it is `pure` [KtLightSourceElement], i.e, compiler created it in light tree mode, then return [unwrapToKtPsiSourceElement] `null`.
-     * Otherwise, return some not-null result.
-     */
-    fun unwrapToKtPsiSourceElement(): KtPsiSourceElement? {
-        if (treeStructure !is WrappedTreeStructure) return null
-        val node = treeStructure.unwrap(lighterASTNode)
-        return node.psi?.toKtPsiSourceElementWithFixedPsi(kind)
-    }
-
     override fun equals(other: Any?): Boolean {
         if (this === other) return true
         if (javaClass != other?.javaClass) return false