Merge duplicated virtual methods in compiler bytecode - IR
diff --git a/prepare/compiler-bytecode-postprocessor/build.gradle.kts b/prepare/compiler-bytecode-postprocessor/build.gradle.kts
new file mode 100644
index 0000000..d781ca8
--- /dev/null
+++ b/prepare/compiler-bytecode-postprocessor/build.gradle.kts
@@ -0,0 +1,12 @@
+// Build script of the compiler-bytecode-postprocessor module: a Kotlin/JVM project.
+plugins {
+ kotlin("jvm")
+}
+
+repositories {
+ mavenCentral()
+}
+
+dependencies {
+ api(kotlinStdlib())
+ // ASM: main.kt imports its ClassReader/ClassWriter and tree API from org.jetbrains.org.objectweb.asm.
+ implementation(commonDependency("org.jetbrains.intellij.deps:asm-all"))
+}
\ No newline at end of file
diff --git a/prepare/compiler-bytecode-postprocessor/src/main/kotlin/org/jetbrains/kotlin/compiler/bytecodePostprocessor/main.kt b/prepare/compiler-bytecode-postprocessor/src/main/kotlin/org/jetbrains/kotlin/compiler/bytecodePostprocessor/main.kt
new file mode 100644
index 0000000..cf141be
--- /dev/null
+++ b/prepare/compiler-bytecode-postprocessor/src/main/kotlin/org/jetbrains/kotlin/compiler/bytecodePostprocessor/main.kt
@@ -0,0 +1,423 @@
+package org.jetbrains.kotlin.compiler.bytecodePostprocessor
+
+import org.jetbrains.org.objectweb.asm.ClassReader
+import org.jetbrains.org.objectweb.asm.ClassWriter
+import org.jetbrains.org.objectweb.asm.Label
+import org.jetbrains.org.objectweb.asm.Opcodes
+import org.jetbrains.org.objectweb.asm.Type
+import org.jetbrains.org.objectweb.asm.tree.*
+import java.util.zip.ZipEntry
+import java.util.zip.ZipInputStream
+import java.util.zip.ZipOutputStream
+import kotlin.and
+import kotlin.inv
+import kotlin.io.path.Path
+import kotlin.io.path.inputStream
+import kotlin.io.path.outputStream
+import kotlin.or
+import kotlin.properties.ReadWriteProperty
+import kotlin.reflect.KProperty
+import kotlin.text.toLong
+
+// Every class entry found in the input jar, keyed by its entry name without the ".class" suffix
+// (the JVM internal name, e.g. "a/b/C"). Filled by loadAllClasses() and read by all later steps.
+private val allClasses = mutableMapOf<String, ClassInfo>()
+
+/**
+ * Entry point of the bytecode postprocessor.
+ *
+ * Expects four arguments, in this order:
+ * 1. path of the input jar;
+ * 2. `;`-separated class path (split and forwarded to [loadAllClasses]);
+ * 3. path of the output jar;
+ * 4. `;`-separated class name patterns selecting the classes to process; an empty string selects every class.
+ *
+ * Extra arguments are ignored.
+ *
+ * @throws IllegalArgumentException if fewer than four arguments are passed
+ */
+fun main(args: Array<String>) {
+    // Report a usage message instead of the ArrayIndexOutOfBoundsException the destructuring below would throw.
+    require(args.size >= 4) {
+        "Expected arguments: <input jar> <class path> <output jar> <class patterns>, but got ${args.size} argument(s)"
+    }
+    val (inputJarFileName, classPathDense, outputFileName, processClassPatternsDense) = args
+
+    loadAllClasses(
+        inputJarFileName,
+        classPathDense.split(';'),
+        // "".split(';') would produce a single empty pattern, so an empty argument is mapped to "no patterns".
+        processClassPatternsDense.let { if (it.isEmpty()) emptyList() else it.split(';') }
+    )
+
+    // Shared across the loop so that a class reached through a superclass is not processed twice.
+    val processedClasses = mutableSetOf<ClassInfo>()
+    for (clazz in allClasses.values) {
+        processClass(clazz, processedClasses)
+    }
+
+    saveResult(outputFileName)
+}
+
+
+/**
+ * Reads every `.class` entry of [inputJarFileName] into [allClasses] and links classes to their direct subclasses.
+ *
+ * A class is "included" when [processClassPatterns] is empty or one of the patterns matches its dot-separated
+ * name: trailing `*` characters are stripped from each pattern, then a pattern ending with `.` matches by prefix
+ * and any other pattern must be equal to the name.
+ *
+ * Included classes are parsed into a [ClassNode]. For the remaining classes only the supertype names are read
+ * from the class file header, so that they still appear in the `directSubclasses` list of their supertypes.
+ * Without this, a class with subclasses outside the processed set would look as if all of its subclasses could be
+ * rewritten, and the `isApplicationClass` guard in `processClass` could never trigger.
+ *
+ * @param classPath not read by this function
+ */
+private fun loadAllClasses(inputJarFileName: String, classPath: List<String>, processClassPatterns: List<String>) {
+    val includeClassPatterns = processClassPatterns.map { it.trimEnd('*') }
+    // Supertype names of the classes that get no ClassNode; markSubclasses() only sees parsed classes.
+    val excludedClassSuperTypes = mutableListOf<Pair<ClassInfo, List<String>>>()
+
+    ZipInputStream(Path(inputJarFileName).inputStream().buffered()).use { inputJar ->
+        while (true) {
+            val entry = inputJar.nextEntry ?: break
+            if (!entry.isDirectory && entry.name.endsWith(".class")) {
+                val internalName = entry.name.removeSuffix(".class")
+                val packageName = internalName.replace('/', '.')
+
+                val include = includeClassPatterns.isEmpty() || includeClassPatterns.any { pattern ->
+                    if (pattern.endsWith('.')) packageName.startsWith(pattern) else packageName == pattern
+                }
+
+                val reader = ClassReader(inputJar.buffered())
+                val clazz = if (include) ClassNode().also { reader.accept(it, 0) } else null
+
+                val classInfo = ClassInfo(internalName, include, clazz, entry)
+                allClasses[internalName] = classInfo
+
+                if (clazz != null) {
+                    clazz.methods.forEach { it.declaringClass = classInfo }
+                } else {
+                    excludedClassSuperTypes += classInfo to (listOfNotNull(reader.superName) + reader.interfaces)
+                }
+            }
+
+            inputJar.closeEntry()
+        }
+    }
+
+    markSubclasses()
+
+    // Register the excluded classes as subclasses too; they carry isApplicationClass == false.
+    for ((classInfo, superTypeNames) in excludedClassSuperTypes) {
+        for (superName in superTypeNames) {
+            allClasses[superName]?.directSubclasses += classInfo
+        }
+    }
+}
+
+/**
+ * Fills the `directSubclasses` lists: every class that has a parsed [ClassNode] is appended to the list of each
+ * of its direct supertypes (superclass and interfaces) that is present in [allClasses].
+ */
+private fun markSubclasses() {
+    for ((_, subClassInfo) in allClasses) {
+        // Classes without a parsed node contribute nothing.
+        val directSuperTypeNames = subClassInfo.classNode?.superTypeNames ?: continue
+        for (superTypeName in directSuperTypeNames) {
+            allClasses[superTypeName]?.directSubclasses += subClassInfo
+        }
+    }
+}
+
+/**
+ * Writes every application class (`ClassInfo.isApplicationClass`) of [allClasses] into a new jar at
+ * [outputFileName]. Each class is serialized from its [ClassNode] and stored under the name and with the
+ * compression method of its original jar entry. Classes that are not application classes are not written.
+ */
+private fun saveResult(outputFileName: String) {
+    ZipOutputStream(Path(outputFileName).outputStream().buffered()).use { outputStream ->
+        for (clazz in allClasses.values) {
+            if (!clazz.isApplicationClass) continue
+            val classNode = checkNotNull(clazz.classNode) { "Application class ${clazz.name} has no class node" }
+
+            // Flags 0: ASM recomputes neither max stack/locals nor frames; the values stored in the node are written.
+            val cw = ClassWriter(0)
+            classNode.accept(cw)
+            val binary = cw.toByteArray()
+
+            val zipEntry = ZipEntry(clazz.zipEntry.name).apply {
+                size = binary.size.toLong()
+                method = clazz.zipEntry.method
+                if (method == ZipEntry.STORED) {
+                    // ZipOutputStream.putNextEntry rejects a STORED entry whose CRC is unknown,
+                    // so compute it for the rewritten bytes.
+                    crc = java.util.zip.CRC32().apply { update(binary) }.value
+                }
+            }
+            outputStream.putNextEntry(zipEntry)
+            outputStream.write(binary)
+            outputStream.closeEntry()
+        }
+    }
+}
+
+/**
+ * Tries to move method implementations that are duplicated in all direct subclasses of [clazz] up into [clazz].
+ *
+ * Nothing is done when [clazz] was already visited, has no parsed class node, is an interface or final,
+ * has no direct subclasses, or has a direct subclass that is not an application class.
+ * Direct subclasses are processed first, so hoisting proceeds bottom-up.
+ *
+ * @param processedClasses classes visited so far; [clazz] is added to it
+ */
+private fun processClass(clazz: ClassInfo, processedClasses: MutableSet<ClassInfo>) {
+    val firstVisit = processedClasses.add(clazz)
+    if (!firstVisit) return
+    val clazzNode = clazz.classNode ?: return
+
+    val isInterface = clazzNode.access and Opcodes.ACC_INTERFACE != 0
+    val isFinal = clazzNode.access and Opcodes.ACC_FINAL != 0
+    if (isInterface || isFinal) return
+
+    val subClasses = clazz.directSubclasses
+    if (subClasses.isEmpty()) return
+    if (subClasses.any { !it.isApplicationClass }) return
+
+    subClasses.forEach { processClass(it, processedClasses) }
+
+    // Names of instance fields that every direct subclass declares with one and the same descriptor.
+    val hoistableFieldNames = subClasses
+        .flatMap { it.classNode!!.fields }
+        .filter { it.access and Opcodes.ACC_STATIC == 0 }
+        .groupBy { it.name }
+        .filterValues { sameName -> sameName.size == subClasses.size }
+        .filterValues { sameName -> sameName.distinctBy { it.desc }.size == 1 }
+        .keys
+
+    // Methods declared by clazz and its application-class supertypes, keyed by name + descriptor.
+    // The declaration found first (closest to clazz in this depth-first walk) wins.
+    val candidateMethods = mutableMapOf<String, MethodNode>()
+    fun collectMethods(owner: ClassInfo) {
+        if (owner !== clazz && !owner.isApplicationClass) return
+        val ownerNode = owner.classNode!!
+        for (method in ownerNode.methods) {
+            candidateMethods.putIfAbsent(method.name + method.desc, method)
+        }
+        for (superTypeName in ownerNode.superTypeNames) {
+            val superType = allClasses[superTypeName] ?: continue
+            collectMethods(superType)
+        }
+    }
+    collectMethods(clazz)
+
+    val hoistedFields = mutableMapOf<String, FieldNode>()
+    candidateMethods.values.forEach { baseMethod ->
+        tryHoistMethod(clazz, baseMethod, subClasses, hoistableFieldNames, hoistedFields)
+    }
+
+    // Fields that now live in clazz are dropped from the subclasses.
+    val hoistedNames = hoistedFields.keys
+    subClasses.forEach { subClass ->
+        subClass.classNode!!.fields.removeIf { it.name in hoistedNames }
+    }
+}
+
+/**
+ * Hoists the implementations of [baseMethod] from [subClasses] into [clazz] when that is possible.
+ *
+ * The hoist happens only if [baseMethod] is abstract and not private, every subclass declares exactly one
+ * non-abstract method with the same name and descriptor, all those implementations are equal according to
+ * [compareMethods], and [checkCanHoist] accepts the first of them. On success a message is printed, the body is
+ * installed in [clazz] by [hoistMethod] and the implementations are removed from the subclasses.
+ *
+ * @param hoistableFields names of the fields that may be moved to [clazz] together with the method
+ * @param hoistedFields fields already moved to [clazz], keyed by name; updated by [hoistMethod]
+ */
+private fun tryHoistMethod(
+    clazz: ClassInfo,
+    baseMethod: MethodNode,
+    subClasses: List<ClassInfo>,
+    hoistableFields: Set<String>,
+    hoistedFields: MutableMap<String, FieldNode>,
+) {
+    val isAbstract = baseMethod.access and Opcodes.ACC_ABSTRACT != 0
+    val isPrivate = baseMethod.access and Opcodes.ACC_PRIVATE != 0
+    if (!isAbstract || isPrivate) return
+
+    val implMethods = subClasses.mapNotNull { subClass ->
+        subClass.classNode!!.methods.singleOrNull { it.name == baseMethod.name && it.desc == baseMethod.desc }
+    }
+    if (implMethods.size != subClasses.size) return
+    if (implMethods.any { it.access and Opcodes.ACC_ABSTRACT != 0 }) return
+
+    // The first implementation is the reference; every other one must match it.
+    val reference = implMethods.firstOrNull()
+    val allIdentical = implMethods.asSequence().drop(1).all { compareMethods(reference!!, it) }
+    if (!allIdentical) return
+
+    val source = reference!!
+    if (!checkCanHoist(source, clazz, hoistableFields)) return
+
+    val subClassNames = subClasses.joinToString { it.className.substringAfterLast('.') }
+    println("Hoist method ${clazz.className}.${baseMethod.name.substringAfterLast('.')} from [$subClassNames]")
+
+    hoistMethod(source, baseMethod, clazz, implMethods, hoistableFields, hoistedFields)
+    subClasses.forEachIndexed { index, subClass ->
+        subClass.classNode!!.methods.remove(implMethods[index])
+    }
+}
+
+/**
+ * Checks whether [aMethod] and [bMethod] have equal bodies.
+ *
+ * Labels, line numbers and frames are skipped; the remaining instructions must be equal pairwise, in the same
+ * number, and the try/catch tables must be equal too. Two labels are equal when they have the same index in
+ * their instruction lists. Instance field accesses (GETFIELD/PUTFIELD) may have different owners provided each
+ * owner is the class declaring the respective method.
+ */
+private fun compareMethods(aMethod: MethodNode, bMethod: MethodNode): Boolean {
+    val aInstructions = aMethod.instructions
+    val bInstructions = bMethod.instructions
+
+    infix fun LabelNode.isSamePositionAs(other: LabelNode): Boolean =
+        aInstructions.indexOf(this) == bInstructions.indexOf(other)
+
+    fun List<LabelNode>.allSamePositionsAs(others: List<LabelNode>): Boolean =
+        size == others.size && indices.all { this[it] isSamePositionAs others[it] }
+
+    fun Iterable<AbstractInsnNode>.filterReal() = filterNot { it is LabelNode || it is LineNumberNode || it is FrameNode }
+
+    val aReal = aInstructions.filterReal()
+    val bReal = bInstructions.filterReal()
+    // zip() stops at the shorter list, so a body that is a prefix of the other one must be rejected explicitly.
+    if (aReal.size != bReal.size) return false
+
+    // Identical instructions guarded by different exception handlers are different methods.
+    val aHandlers = aMethod.tryCatchBlocks.orEmpty()
+    val bHandlers = bMethod.tryCatchBlocks.orEmpty()
+    if (aHandlers.size != bHandlers.size) return false
+    for ((a, b) in aHandlers zip bHandlers) {
+        if (a.type != b.type) return false
+        if (!(a.start isSamePositionAs b.start)) return false
+        if (!(a.end isSamePositionAs b.end)) return false
+        if (!(a.handler isSamePositionAs b.handler)) return false
+    }
+
+    for ((a, b) in aReal zip bReal) {
+        // Equal opcodes imply the same node class, which makes the casts below safe.
+        if (a.opcode != b.opcode) return false
+        when (a) {
+            is InsnNode -> {} // no operands
+            is VarInsnNode -> {
+                b as VarInsnNode
+                if (a.`var` != b.`var`) return false
+            }
+            is IntInsnNode -> {
+                b as IntInsnNode
+                if (a.operand != b.operand) return false
+            }
+            is IincInsnNode -> {
+                b as IincInsnNode
+                if (a.`var` != b.`var`) return false
+                if (a.incr != b.incr) return false
+            }
+            is TypeInsnNode -> {
+                b as TypeInsnNode
+                if (a.desc != b.desc) return false
+            }
+            is MultiANewArrayInsnNode -> {
+                b as MultiANewArrayInsnNode
+                if (a.desc != b.desc) return false
+                if (a.dims != b.dims) return false
+            }
+            is LdcInsnNode -> {
+                b as LdcInsnNode
+                if (a.cst != b.cst) return false
+            }
+            is JumpInsnNode -> {
+                b as JumpInsnNode
+                if (!(a.label isSamePositionAs b.label)) return false
+            }
+            is TableSwitchInsnNode -> {
+                b as TableSwitchInsnNode
+                if (a.min != b.min) return false
+                if (a.max != b.max) return false
+                if (!a.labels.allSamePositionsAs(b.labels)) return false
+                if (!(a.dflt isSamePositionAs b.dflt)) return false
+            }
+            is LookupSwitchInsnNode -> {
+                b as LookupSwitchInsnNode
+                if (a.keys != b.keys) return false
+                if (!a.labels.allSamePositionsAs(b.labels)) return false
+                if (!(a.dflt isSamePositionAs b.dflt)) return false
+            }
+            is MethodInsnNode -> {
+                b as MethodInsnNode
+                if (a.owner != b.owner) return false
+                if (a.name != b.name) return false
+                if (a.desc != b.desc) return false
+            }
+            is InvokeDynamicInsnNode -> {
+                b as InvokeDynamicInsnNode
+                if (a.name != b.name) return false
+                if (a.desc != b.desc) return false
+                if (a.bsm != b.bsm) return false
+                // Arrays have reference equality; compare the bootstrap arguments element by element.
+                if (!a.bsmArgs.contentEquals(b.bsmArgs)) return false
+            }
+            is FieldInsnNode -> {
+                b as FieldInsnNode
+                if (a.name != b.name) return false
+                if (a.desc != b.desc) return false
+                if (a.owner != b.owner) {
+                    // Different owners are tolerated only for instance fields of the methods' own classes.
+                    val isInstanceAccess = a.opcode == Opcodes.GETFIELD || a.opcode == Opcodes.PUTFIELD
+                    if (!isInstanceAccess ||
+                        a.owner != aMethod.declaringClass.name ||
+                        b.owner != bMethod.declaringClass.name
+                    ) return false
+                }
+            }
+        }
+    }
+
+    return true
+}
+
+/**
+ * Tells whether the body of [source] may be moved into a superclass.
+ *
+ * The body is rejected when it contains an `invokedynamic` instruction, a method call whose owner is the class
+ * declaring [source], or a GETFIELD/PUTFIELD on a field of that class whose name is not in [hoistableFields].
+ * Everything else is accepted; visibility of the referenced types and members is not checked.
+ *
+ * @param targetClass not read by this function
+ */
+private fun checkCanHoist(source: MethodNode, targetClass: ClassInfo, hoistableFields: Set<String>): Boolean {
+    // True when [owner] is the class that declares [source].
+    fun isSourceClass(owner: String) = owner == source.declaringClass.name
+
+    fun blocksHoisting(inst: AbstractInsnNode): Boolean = when (inst) {
+        is InvokeDynamicInsnNode -> true
+        is MethodInsnNode -> isSourceClass(inst.owner)
+        is FieldInsnNode -> {
+            val isInstanceAccess = inst.opcode == Opcodes.GETFIELD || inst.opcode == Opcodes.PUTFIELD
+            isInstanceAccess && isSourceClass(inst.owner) && inst.name !in hoistableFields
+        }
+        else -> false
+    }
+
+    return source.instructions.none { blocksHoisting(it) }
+}
+
+
+/**
+ * Installs the body of [source] in [targetClass] as the implementation of [baseMethod].
+ *
+ * If [baseMethod] is declared by [targetClass] it is filled in place; otherwise a new method with the access flags,
+ * name, descriptor, signature and exceptions of [baseMethod] is added to [targetClass]. The target loses
+ * `ACC_ABSTRACT` and becomes final when every method of [allSourceMethods] is final. Instructions are cloned with
+ * fresh labels (line numbers are dropped) and the try/catch table is cloned with the same label mapping.
+ *
+ * Instance field accesses whose owner is the class declaring [source] are redirected to [targetClass]; the
+ * accessed field is created there on first use (at least protected, never final) and recorded in [hoistedFields].
+ *
+ * @throws IllegalStateException if such a field access names a field that is not in [hoistableFields]
+ */
+private fun hoistMethod(
+    source: MethodNode,
+    baseMethod: MethodNode,
+    targetClass: ClassInfo,
+    allSourceMethods: List<MethodNode>,
+    hoistableFields: Set<String>,
+    hoistedFields: MutableMap<String, FieldNode>,
+) {
+    val targetNode = checkNotNull(targetClass.classNode) { "Target class ${targetClass.name} has no class node" }
+
+    val target: MethodNode =
+        if (baseMethod.declaringClass == targetClass) {
+            baseMethod
+        } else {
+            MethodNode(
+                baseMethod.access, baseMethod.name, baseMethod.desc, baseMethod.signature, baseMethod.exceptions.toTypedArray()
+            ).also {
+                targetNode.methods.add(it)
+                it.declaringClass = targetClass
+            }
+        }
+
+    target.access = target.access and Opcodes.ACC_ABSTRACT.inv()
+    // NOTE(review): the original tested ACC_FINAL twice with a "condition may be better" remark; the second
+    // operand was presumably meant to test something else (e.g. a final declaring class) — confirm.
+    if (allSourceMethods.all { it.access and Opcodes.ACC_FINAL != 0 }) {
+        target.access = target.access or Opcodes.ACC_FINAL
+    }
+
+    val sourceBody = source.instructions
+    val targetBody = target.instructions
+    target.maxStack = source.maxStack
+    target.maxLocals = source.maxLocals
+
+    // Every label of the source gets a fresh counterpart, so the copy shares no label with the original.
+    val clonedLabels: Map<LabelNode, LabelNode> = buildMap {
+        for (inst in sourceBody) {
+            if (inst is LabelNode) {
+                put(inst, LabelNode(Label()))
+            }
+        }
+    }
+    for (inst in sourceBody) {
+        val copy = inst.clone(clonedLabels) ?: continue
+        if (copy !is LineNumberNode) {
+            targetBody.add(copy)
+        }
+    }
+
+    // Without the exception table the hoisted body would lose all of its catch/finally handlers.
+    val targetHandlers = target.tryCatchBlocks
+        ?: mutableListOf<TryCatchBlockNode>().also { target.tryCatchBlocks = it }
+    for (block in source.tryCatchBlocks.orEmpty()) {
+        targetHandlers.add(
+            TryCatchBlockNode(
+                clonedLabels.getValue(block.start),
+                clonedLabels.getValue(block.end),
+                clonedLabels.getValue(block.handler),
+                block.type,
+            )
+        )
+    }
+
+    for (inst in targetBody) {
+        if (inst !is FieldInsnNode) continue
+        if (inst.opcode != Opcodes.GETFIELD && inst.opcode != Opcodes.PUTFIELD) continue
+        if (inst.owner != source.declaringClass.name) continue
+
+        check(inst.name in hoistableFields) { "Field ${inst.name} cannot be hoisted" }
+
+        // Resolve before the owner is rewritten: the declaration is looked up in the source class.
+        val field = inst.resolve()
+        val newField = hoistedFields.computeIfAbsent(inst.name) {
+            val access = mergeModifierVisibilities(field.access, Opcodes.ACC_PROTECTED)
+            FieldNode(access, field.name, field.desc, field.signature, field.value).also {
+                targetNode.fields.add(it)
+            }
+        }
+
+        // Final fields cannot be written from a subclass, so the hoisted field is never final.
+        newField.access = mergeModifierVisibilities(newField.access, field.access) and Opcodes.ACC_FINAL.inv()
+
+        inst.owner = targetClass.name
+    }
+}
+
+/**
+ * Returns [modifiers] with its visibility bits replaced by the wider of the visibilities found in [modifiers]
+ * and [visibility]. From widest to narrowest: public, protected, package-private (no visibility bit), private.
+ * All non-visibility bits of [modifiers] are kept; non-visibility bits of [visibility] are ignored.
+ */
+private fun mergeModifierVisibilities(modifiers: Int, visibility: Int): Int {
+    val allVisibilities = Opcodes.ACC_PRIVATE or Opcodes.ACC_PROTECTED or Opcodes.ACC_PUBLIC
+    val otherModifiers = modifiers and allVisibilities.inv()
+
+    fun eitherHas(flag: Int): Boolean = (modifiers or visibility) and flag != 0
+
+    return when {
+        eitherHas(Opcodes.ACC_PUBLIC) -> otherModifiers or Opcodes.ACC_PUBLIC
+        // Protected grants package access plus subclass access, so it must win over package-private.
+        eitherHas(Opcodes.ACC_PROTECTED) -> otherModifiers or Opcodes.ACC_PROTECTED
+        modifiers and allVisibilities == 0 || visibility and allVisibilities == 0 -> otherModifiers
+        else -> otherModifiers or Opcodes.ACC_PRIVATE
+    }
+}
+
+/**
+ * One class entry of the input jar.
+ *
+ * @property name JVM internal name of the class (the jar entry name without `.class`)
+ * @property isApplicationClass whether the class matched the include patterns and may be rewritten
+ * @property classNode parsed class, or `null` when the class was not parsed
+ * @property zipEntry the jar entry the class was read from
+ */
+private class ClassInfo(
+    val name: String,
+    val isApplicationClass: Boolean,
+    val classNode: ClassNode?,
+    val zipEntry: ZipEntry,
+) {
+    /** Classes that list this class as their superclass or as one of their interfaces. */
+    val directSubclasses: MutableList<ClassInfo> = ArrayList()
+
+    /** ASM object type built from [name]. */
+    val type: Type
+        get() = Type.getObjectType(name)
+
+    /** Class name as reported by [Type.getClassName] for [type]. */
+    val className: String
+        get() = Type.getObjectType(name).className
+
+    override fun toString(): String {
+        return name
+    }
+}
+
+/** Internal names of the direct supertypes: the superclass, when there is one, followed by the interfaces. */
+private val ClassNode.superTypeNames: List<String?>
+    get() {
+        val superClassName = listOfNotNull(superName)
+        return superClassName + interfaces
+    }
+
+/**
+ * The [ClassInfo] a method belongs to, kept in a side table because [MethodNode] has no such property.
+ * Reading it for a method that was never assigned one throws [NoSuchElementException].
+ */
+private var MethodNode.declaringClass: ClassInfo by object : ReadWriteProperty<MethodNode, ClassInfo> {
+    private val owners = HashMap<MethodNode, ClassInfo>()
+
+    override fun getValue(thisRef: MethodNode, property: KProperty<*>): ClassInfo {
+        return owners.getValue(thisRef)
+    }
+
+    override fun setValue(thisRef: MethodNode, property: KProperty<*>, value: ClassInfo) {
+        owners.put(thisRef, value)
+    }
+}
+
+/**
+ * Returns the first field named like this instruction's field among the fields declared by the `owner` class.
+ * Fails if the owner is not in [allClasses], has no parsed class node, or declares no such field.
+ */
+private fun FieldInsnNode.resolve() =
+    allClasses.getValue(owner)
+        .classNode!!
+        .fields
+        .first { declared -> declared.name == this.name }
diff --git a/prepare/compiler/build.gradle.kts b/prepare/compiler/build.gradle.kts
index e8b12b1..6a3288a 100644
--- a/prepare/compiler/build.gradle.kts
+++ b/prepare/compiler/build.gradle.kts
@@ -1,5 +1,7 @@
@file:Suppress("HasPlatformType")
+import org.gradle.api.tasks.TaskProvider
+import org.gradle.api.tasks.bundling.Jar
import org.gradle.internal.jvm.Jvm
import java.util.regex.Pattern.quote
@@ -346,7 +348,48 @@
printconfiguration(layout.buildDirectory.file("compiler.pro.dump"))
}
-val pack: TaskProvider<out DefaultTask> = if (kotlinBuildProperties.proguard) proguard else packCompiler
+// Classpath used to launch the bytecode postprocessor (see runPostprocessing below).
+val postprocessorClasspath by configurations.creating
+dependencies {
+ postprocessorClasspath(project(":prepare:compiler-bytecode-postprocessor"))
+}
+
+// Runs the bytecode postprocessor on the jar produced by the proguard task and writes
+// "<compilerBaseName>-after-postprocessing.jar" into build/libs.
+val runPostprocessing by tasks.registering(NoDebugJavaExec::class) {
+ dependsOn(proguard)
+
+ val inputFile = proguard.map { it.singleOutputFile(layout) }
+ val outputFile = layout.buildDirectory.file("libs/$compilerBaseName-after-postprocessing.jar")
+
+ inputs.file(inputFile)
+ outputs.file(outputFile)
+
+ javaLauncher.set(project.getToolchainLauncherFor(JdkMajorVersion.JDK_1_8))
+ classpath = postprocessorClasspath
+ mainClass = "org.jetbrains.kotlin.compiler.bytecodePostprocessor.MainKt"
+ // Positional arguments, in the order MainKt destructures them:
+ // input jar, ';'-joined class path, output jar, ';'-joined class patterns.
+ // NOTE(review): the providers are resolved with get() while this block runs — confirm that is acceptable here.
+ args(
+ inputFile.get(),
+ proguardLibraries.files.joinToString(";"),
+ outputFile.get(),
+ listOf("org.jetbrains.kotlin.ir.*", "org.jetbrains.kotlin.fir.lazy.*").joinToString(";")
+ )
+}
+
+// Builds the "postprocessed-merged" jar from the postprocessor output plus the proguard output.
+// The postprocessor output is added first and duplicates are excluded, so for a path present in both
+// inputs the post-processed file is the one that ends up in the jar.
+val mergePostprocessedJar by task<Jar> {
+ duplicatesStrategy = DuplicatesStrategy.EXCLUDE
+ destinationDirectory.set(layout.buildDirectory.dir("libs"))
+ archiveClassifier.set("postprocessed-merged")
+
+ dependsOn(runPostprocessing)
+ from {
+ runPostprocessing.map { zipTree(it.singleOutputFile(layout)) }
+ }
+
+ // Everything the postprocessor did not write is taken from the proguard jar.
+ dependsOn(proguard)
+ from {
+ proguard.map { zipTree(it.singleOutputFile(layout)) }
+ }
+}
+
+// With proguard enabled the packed compiler is the merged post-processed jar; otherwise it is packCompiler's output.
+val pack: TaskProvider<out DefaultTask> = if (kotlinBuildProperties.proguard) mergePostprocessedJar else packCompiler
val distDir: String by rootProject.extra
val jar = runtimeJar {
diff --git a/settings.gradle b/settings.gradle
index a8288bf..795e043 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -951,4 +951,6 @@
include ':native:kotlin-test-native-xctest'
include ':native:cli-native'
include ':native:native.tests:cli-tests'
-}
\ No newline at end of file
+}
+
+// Bytecode postprocessor tool referenced by prepare/compiler/build.gradle.kts.
+include 'prepare:compiler-bytecode-postprocessor'