Migrate AppendJava8LegacyDex and BuildLegacyDex action from native to Starlark

PiperOrigin-RevId: 535530930
Change-Id: I5898164105a2f235ef51a6de1193abb1ee71f468
diff --git a/rules/android_binary_internal/impl.bzl b/rules/android_binary_internal/impl.bzl
index 0bf92eb..72d2b7d 100644
--- a/rules/android_binary_internal/impl.bzl
+++ b/rules/android_binary_internal/impl.bzl
@@ -196,16 +196,25 @@
 
 def _process_dex(ctx, packaged_resources_ctx, jvm_ctx, deploy_ctx, **_unused_ctxs):
     providers = []
+    classes_dex_zip = None
+    final_classes_dex_zip = None
+    deploy_jar = deploy_ctx.deploy_jar
+    is_binary_optimized = len(ctx.attr.proguard_specs) > 0
 
     if acls.in_android_binary_starlark_dex_desugar_proguard(str(ctx.label)):
         runtime_jars = jvm_ctx.java_info.runtime_output_jars + [packaged_resources_ctx.class_jar]
         forbidden_dexopts = ctx.fragments.android.get_target_dexopts_that_prevent_incremental_dexing
-        classes_dex_zip = None
+        java8_legacy_dex, java8_legacy_dex_map = _dex.get_java8_legacy_dex_and_map(
+            ctx,
+            android_jar = ctx.attr._android_sdk[AndroidSdkInfo].android_jar,
+            binary_jar = deploy_jar,
+            build_customized_files = is_binary_optimized,
+        )
 
         incremental_dexing = _dex.get_effective_incremental_dexing(
             force_incremental_dexing = ctx.attr.incremental_dexing,
             has_forbidden_dexopts = len([d for d in ctx.attr.dexopts if d in forbidden_dexopts]) > 0,
-            is_binary_optimized = len(ctx.attr.proguard_specs) > 0,
+            is_binary_optimized = is_binary_optimized,
             incremental_dexing_after_proguard_by_default = ctx.fragments.android.incremental_dexing_after_proguard_by_default,
             incremental_dexing_shards_after_proguard = ctx.fragments.android.incremental_dexing_shards_after_proguard,
             use_incremental_dexing = ctx.fragments.android.use_incremental_dexing,
@@ -226,11 +235,23 @@
                 dexmerger = get_android_toolchain(ctx).dexmerger.files_to_run,
             )
 
+        if ctx.fragments.android.desugar_java8_libs and classes_dex_zip.extension == "zip":
+            final_classes_dex_zip = _dex.get_dx_artifact(ctx, "final_classes_dex_zip")
+            _dex.append_java8_legacy_dex(
+                ctx,
+                output = final_classes_dex_zip,
+                input = classes_dex_zip,
+                java8_legacy_dex = java8_legacy_dex,
+                dex_zips_merger = get_android_toolchain(ctx).dex_zips_merger.files_to_run,
+            )
+        else:
+            final_classes_dex_zip = classes_dex_zip
+
         providers.append(
             AndroidDexInfo(
-                deploy_jar = deploy_ctx.deploy_jar,
-                final_classes_dex_zip = classes_dex_zip,
-                java_resource_jar = deploy_ctx.deploy_jar,
+                deploy_jar = deploy_jar,
+                final_classes_dex_zip = final_classes_dex_zip,
+                java_resource_jar = deploy_jar,
             ),
         )
 
diff --git a/rules/dex.bzl b/rules/dex.bzl
index 52df692..cc0b44d 100644
--- a/rules/dex.bzl
+++ b/rules/dex.bzl
@@ -14,7 +14,7 @@
 
 """Bazel Dex Commands."""
 
-load(":utils.bzl", "utils")
+load(":utils.bzl", "get_android_toolchain", "utils")
 load(":providers.bzl", "StarlarkAndroidDexInfo")
 load("@bazel_skylib//lib:collections.bzl", "collections")
 load("//rules:attrs.bzl", _attrs = "attrs")
@@ -67,6 +67,29 @@
 
     return classes_dex_zip
 
+def _append_java8_legacy_dex(
+        ctx,
+        output = None,
+        input = None,
+        java8_legacy_dex = None,
+        dex_zips_merger = None):
+    args = ctx.actions.args()
+
+    # Order matters here: we want java8_legacy_dex to be the highest-numbered classesN.dex
+    args.add("--input_zip", input)
+    args.add("--input_zip", java8_legacy_dex)
+    args.add("--output_zip", output)
+
+    ctx.actions.run(
+        executable = dex_zips_merger,
+        inputs = [input, java8_legacy_dex],
+        outputs = [output],
+        arguments = [args],
+        mnemonic = "AppendJava8LegacyDex",
+        use_default_shell_env = True,
+        progress_message = "Adding Java8 legacy library for %s" % ctx.label,
+    )
+
 def _to_dexed_classpath(dex_archives_dict = {}, classpath = [], runtime_jars = []):
     dexed_classpath = []
     for jar in classpath:
@@ -150,6 +173,32 @@
     # use_incremental_dexing config flag will take effect if incremental_dexing attr is not set
     return use_incremental_dexing
 
+def _get_java8_legacy_dex_and_map(ctx, build_customized_files = False, binary_jar = None, android_jar = None):
+    if not build_customized_files:
+        return utils.only(get_android_toolchain(ctx).java8_legacy_dex.files.to_list()), None
+    else:
+        java8_legacy_dex_rules = _get_dx_artifact(ctx, "_java8_legacy.dex.pgcfg")
+        java8_legacy_dex_map = _get_dx_artifact(ctx, "_java8_legacy.dex.map")
+        java8_legacy_dex = _get_dx_artifact(ctx, "_java8_legacy.dex.zip")
+
+        args = ctx.actions.args()
+        args.add("--rules", java8_legacy_dex_rules)
+        args.add("--binary", binary_jar)
+        args.add("--android_jar", android_jar)
+        args.add("--output", java8_legacy_dex)
+        args.add("--output_map", java8_legacy_dex_map)
+
+        ctx.actions.run(
+            executable = get_android_toolchain(ctx).build_java8_legacy_dex.files_to_run,
+            inputs = [binary_jar, android_jar],
+            outputs = [java8_legacy_dex_rules, java8_legacy_dex_map, java8_legacy_dex],
+            arguments = [args],
+            mnemonic = "BuildLegacyDex",
+            progress_message = "Building Java8 legacy library for %s" % ctx.label,
+        )
+
+        return java8_legacy_dex, java8_legacy_dex_map
+
 def _dex_merge(
         ctx,
         output = None,
@@ -206,10 +255,12 @@
     return collections.uniq(sorted([_dx_to_dexbuilder(token) for token in tokenized_dexopts]))
 
 dex = struct(
+    append_java8_legacy_dex = _append_java8_legacy_dex,
     dex = _dex,
     dex_merge = _dex_merge,
     get_dx_artifact = _get_dx_artifact,
     get_effective_incremental_dexing = _get_effective_incremental_dexing,
+    get_java8_legacy_dex_and_map = _get_java8_legacy_dex_and_map,
     incremental_dexopts = _incremental_dexopts,
     merge_infos = _merge_infos,
     normalize_dexopts = _normalize_dexopts,
diff --git a/toolchains/android/toolchain.bzl b/toolchains/android/toolchain.bzl
index 641b3c5..94c4f85 100644
--- a/toolchains/android/toolchain.bzl
+++ b/toolchains/android/toolchain.bzl
@@ -222,6 +222,21 @@
         default = "@bazel_tools//tools/android:zip_filter",
         executable = True,
     ),
+    dex_zips_merger = attr.label(
+        cfg = "exec",
+        default = "@bazel_tools//tools/android:merge_dexzips",
+        executable = True,
+    ),
+    java8_legacy_dex = attr.label(
+        allow_single_file = True,
+        cfg = "exec",
+        default = "@bazel_tools//tools/android:java8_legacy_dex",
+    ),
+    build_java8_legacy_dex = attr.label(
+        cfg = "exec",
+        default = "@bazel_tools//tools/android:build_java8_legacy_dex",
+        executable = True,
+    ),
 )
 
 def _impl(ctx):