refactor: unnested _py_package_impl logic Signed-off-by: Thulio Ferraz Assis <3149049+f0rmiga@users.noreply.github.com>
diff --git a/python/private/py_package.bzl b/python/private/py_package.bzl index 7975a8d..ec8126b 100644 --- a/python/private/py_package.bzl +++ b/python/private/py_package.bzl
@@ -41,26 +41,12 @@ filtered_inputs = input_files imports = [dep[PyInfo].imports for dep in ctx.attr.deps] else: + input_files_list = input_files.to_list() filtered_files = [] for dep in ctx.attr.deps: - add_imports = False - - # TODO: flattening depset to list gives poor performance, - for input_file in input_files.to_list(): - path_inside_wheel = _path_inside_wheel(input_file) - for package in packages: - if path_inside_wheel.startswith(package) and not exclude: - filtered_files.append(input_file) - if path_inside_wheel.endswith(".py"): - add_imports = True - else: - for excluded in exclude: - if path_inside_wheel.startswith(package) and not path_inside_wheel.startswith(excluded): - filtered_files.append(input_file) - if path_inside_wheel.endswith(".py"): - add_imports = True - if add_imports: - imports.append(dep[PyInfo].imports) + (dep_imports, dep_filtered_files) = _py_package_add_imports(dep, input_files_list, packages, exclude) + imports.append(dep_imports) + filtered_files.extend(dep_filtered_files) filtered_inputs = depset(direct = filtered_files) return [ @@ -73,6 +59,46 @@ ), ] +def _py_package_add_imports(dep, input_files, packages, exclude): + """Returns the imports and filtered files for the given dep based on the packages and exclude list. + + Args: + dep: Dependency to add imports for. + input_files: List of input files. + packages: List of packages to include. + exclude: List of packages to exclude. + """ + filtered_files = [] + add_imports = False + + for input_file in input_files: + path_inside_wheel = _path_inside_wheel(input_file) + for package in packages: + if _py_package_should_include_file(path_inside_wheel, package, exclude): + filtered_files.append(input_file) + add_imports = add_imports or path_inside_wheel.endswith(".py") + + if add_imports: + return (dep[PyInfo].imports, filtered_files) + + return ([], filtered_files) + +def _py_package_should_include_file(path_inside_wheel, package, exclude): + """Returns true if the file should be included in the wheel based on the package and exclude list. + + Args: + path_inside_wheel: Path of the file inside the wheel. + package: Package name to include. + exclude: List of packages to exclude. + """ + if path_inside_wheel.startswith(package): + if not exclude: + return True + for excluded in exclude: + if not path_inside_wheel.startswith(excluded): + return True + return False + py_package_lib = struct( implementation = _py_package_impl, attrs = {