[FIR] make FirCorrespondingSupertypesCache thread-safe ^KT-50244
diff --git a/compiler/fir/providers/src/org/jetbrains/kotlin/fir/types/FirCorrespondingSupertypesCache.kt b/compiler/fir/providers/src/org/jetbrains/kotlin/fir/types/FirCorrespondingSupertypesCache.kt index 7b13873..267ad71 100644 --- a/compiler/fir/providers/src/org/jetbrains/kotlin/fir/types/FirCorrespondingSupertypesCache.kt +++ b/compiler/fir/providers/src/org/jetbrains/kotlin/fir/types/FirCorrespondingSupertypesCache.kt
@@ -8,6 +8,7 @@ import org.jetbrains.kotlin.fir.FirSession import org.jetbrains.kotlin.fir.FirSessionComponent import org.jetbrains.kotlin.fir.ThreadSafeMutableState +import org.jetbrains.kotlin.fir.caches.firCachesFactory import org.jetbrains.kotlin.fir.declarations.FirClassLikeDeclaration import org.jetbrains.kotlin.fir.declarations.FirTypeParameterRefsOwner import org.jetbrains.kotlin.fir.resolve.toSymbol @@ -20,7 +21,13 @@ @ThreadSafeMutableState class FirCorrespondingSupertypesCache(private val session: FirSession) : FirSessionComponent { - private val cache = HashMap<ConeClassLikeLookupTag, Map<ConeClassLikeLookupTag, List<ConeClassLikeType>>?>(1000, 0.5f) + private val cache = + session.firCachesFactory.createCache<ConeClassLikeLookupTag, Map<ConeClassLikeLookupTag, List<ConeClassLikeType>>?, TypeCheckerState>( + initialCapacity = 1000, + loadFactor = 0.5f + ) { lookupTag, typeCheckerState -> + computeSupertypesMap(lookupTag, typeCheckerState) + } fun getCorrespondingSupertypes( type: ConeKotlinType, @@ -36,11 +43,9 @@ val lookupTag = type.lookupTag if (lookupTag == supertypeConstructor) return listOf(captureType(type, typeContext)) - if (lookupTag !in cache) { - cache[lookupTag] = computeSupertypesMap(lookupTag, typeCheckerState) - } - val resultTypes = cache[lookupTag]?.getOrDefault(supertypeConstructor, emptyList()) ?: return null + val resultTypes = + cache.getValue(lookupTag, typeCheckerState)?.getOrDefault(supertypeConstructor, emptyList()) ?: return null if (type.typeArguments.isEmpty()) return resultTypes val capturedType = captureType(type, typeContext)