# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utility functions for rolling CIPD package versions."""

from __future__ import annotations

import collections
import dataclasses
import json
import re
from typing import TYPE_CHECKING

from PB.recipe_modules.fuchsia.cipd_util.upload_manifest import (
    CIPDUploadManifest,
)
from RECIPE_MODULES.fuchsia.roll_commit_message import (
    api as roll_commit_message_api,
)

from recipe_engine import recipe_api

if TYPE_CHECKING:  # pragma: no cover
    from typing import Sequence
    from recipe_engine import config_types
    from RECIPE_MODULES.pigweed.checkout import api as checkout_api


@dataclasses.dataclass(order=True)
class Roll(roll_commit_message_api.BaseRoll):
    """One CIPD package moving from an old version tag to a new one.

    Because the dataclass uses order=True, instances compare field by field
    in declaration order: package_name, then old_version, then new_version.
    """

    # Name used for the package in the commit message header.
    package_name: str
    # Version tag that was pinned before the roll.
    old_version: str
    # Version tag that is pinned after the roll.
    new_version: str

    def short_name(self) -> str:
        """Return the package name as the short label for this roll."""
        label = self.package_name
        return label

    def message_header(self, force_summary_version: bool = False) -> str:
        """Return the commit-message header, which is the package name.

        `force_summary_version` is accepted (presumably to match the BaseRoll
        interface) but ignored.
        """
        header = self.package_name
        return header

    def message_body(
        self,
        *,
        force_summary_version: bool = False,
        escape_tags: Sequence[str] = (),
        filter_tags: Sequence[str] = (),
    ) -> str:
        """Return a two-line body naming the old and new versions.

        The keyword-only options are accepted (presumably to match the
        BaseRoll interface) but ignored.
        """
        body = f'From {self.old_version}\nTo {self.new_version}'
        return body

    def message_footer(self, *, send_comment: bool) -> str:
        """Return an empty footer; `send_comment` is ignored."""
        footer = ''
        return footer


class CipdRollApi(recipe_api.RecipeApi):
    """Utility functions for rolling CIPD package versions.

    The main entry point is `update_package`, which bumps the version tag of
    one package listed in a CIPD JSON file inside a checkout.
    """

    # Exposed on the API object so callers can reach the Roll type.
    Roll = Roll

    def is_platform(self, part: str) -> bool:
        """Return true for platform-style strings.

        Example matches: "linux-amd64", "${platform}", "${os}-amd64", "cp38".
        Example non-matches: "clang", "amd64", "linux".

        Args:
            part: One '/'-separated component of a CIPD package path.

        Returns:
            True if `part` contains a '{' placeholder, starts with a Python
            ABI tag such as "cp38", or is exactly "<os>-<suffix>" (a single
            '-') with <os> being linux, mac or windows.
        """
        if '{' in part:
            return True

        # Match Python version indicators.
        if re.match(r'cp\d+', part):
            return True

        # Exactly one '-' separating a known OS from an architecture.
        pieces = part.split('-')
        return len(pieces) == 2 and pieces[0] in ('linux', 'mac', 'windows')

    def find_shared_tags(
        self,
        package_tags: dict[str, set[str]],
        tag: str,
    ) -> set[str]:
        """Attempts to find a tag shared by all packages.

        This function can be used if the intersection of the sets of tags
        associated with different-platform packages with the same 'ref' is
        empty. It finds a tag shared by all packages, with as many of them as
        possible matching 'ref'.

        Candidate tags are tried from most to least common across
        `package_tags`. Ties are broken lexicographically, so the result does
        not depend on set iteration order.

        Args:
            package_tags: Maps CIPD package path to the set of tags (all
                starting with `tag + ':'`) of the instance currently selected
                for that path. On success, entries are replaced for packages
                that had to switch to a different instance.
            tag: Tag key, i.e. the part before ':', used to filter the tags of
                newly described instances.

        Returns:
            The non-empty set of tags shared by all packages after switching,
            or an empty set if no candidate works.
        """
        # Count how many packages already carry each tag, then order the
        # candidates by (descending count, tag) for determinism.
        counter = collections.Counter(
            t for path_tags in package_tags.values() for t in path_tags
        )
        candidates = sorted(counter, key=lambda t: (-counter[t], t))
        package_paths = sorted(package_tags)

        with self.m.step.nest("find shared tag"):
            for tag_candidate in candidates:
                # For every package whose selected instance lacks this tag,
                # look for another instance of that package that has it.
                updated_tags = {}
                for package_path in package_paths:
                    if tag_candidate in package_tags[package_path]:
                        # The selected instance already carries the tag.
                        continue
                    try:
                        package_data = self.m.cipd.describe(
                            package_path, tag_candidate
                        )
                    except self.m.step.StepFailure:
                        # No instance of this package has the tag; try the
                        # next candidate.
                        break
                    updated_tags[package_path] = {
                        x.tag
                        for x in package_data.tags
                        if x.tag.startswith(tag + ':')
                    }

                else:
                    # We found a version of each package with the tag_candidate.
                    merged_tags = {**package_tags, **updated_tags}
                    tags = set.intersection(*merged_tags.values())
                    # Should always succeed: tag_candidate is in every set.
                    assert tags, 'expected at least one shared tag'
                    # Keep the caller's mapping consistent with the returned
                    # tags.
                    package_tags.update(updated_tags)
                    return tags

        # We failed to find any tag that meets our criteria.
        return set()

    def update_package(
        self,
        checkout_root: config_types.Path,
        pkg: str,
    ) -> list[Roll]:
        """Roll one CIPD package pinned in a JSON file to a newer version.

        The steps are:
        1. Read the JSON file at `pkg.json_path` (relative to `checkout_root`)
           and find the entry whose 'path' equals `pkg.spec`.
        2. Describe `pkg.ref` for every platform variant of the package.
        3. Pick the lexicographically smallest `pkg.tag` value shared by all
           variants that have the ref.
        4. Rewrite the file if the pinned version is not among the shared
           tags.

        NOTE(review): `pkg` is annotated as `str`, but this method reads
        `.name`, `.spec`, `.json_path`, `.ref`, `.tag` and
        `.allow_mismatched_refs` from it (and may assign `.name`). It is
        really a structured package config; the annotation should be fixed.

        Args:
            checkout_root: Root of the checkout containing the JSON file.
            pkg: Package roll configuration (see the note above).

        Returns:
            An empty list if the pinned version is already among the common
            tags, otherwise a single-element list describing the roll.

        Raises:
            StepFailure: If the package is missing from the JSON file, no tag
                is shared by all platform variants, the chosen version looks
                like a pre-release, or describing the chosen version fails for
                any platform package.
        """
        if not pkg.name:
            # Turn foo/bar/baz/${platform} and foo/bar/baz/${os=mac}-${arch}
            # into 'baz'.
            pkg.name = [
                part
                for part in pkg.spec.split('/')
                if not self.is_platform(part)
            ][-1]

        with self.m.step.nest(pkg.name):
            # Accept either '/' or '\' as the separator in json_path.
            json_path = checkout_root.joinpath(
                *re.split(r'[\\/]+', pkg.json_path),
            )

            basename = self.m.path.basename(json_path)
            cipd_json = self.m.file.read_json(f'read {basename}', json_path)
            # The file is either a bare list of packages or an object holding
            # a 'packages' list.
            packages = cipd_json
            if isinstance(cipd_json, dict):
                packages = cipd_json['packages']

            # `package` ends up referencing the matching entry inside
            # cipd_json, so mutating it later changes what gets written back.
            old_version = None
            package = None
            for package in packages:
                if package['path'] == pkg.spec:
                    old_version = package['tags'][0]
                    break
            else:
                raise self.m.step.StepFailure(
                    f"couldn't find package {pkg.spec} in {json_path}"
                )

            assert package.get('platforms'), 'platforms empty in json'
            platforms = package.get('platforms')
            base, name = pkg.spec.rstrip('/').rsplit('/', 1)
            if self.is_platform(name):
                # Expand the platform placeholder into one path per platform.
                package_paths = [f'{base}/{x}' for x in platforms]
            else:
                package_paths = [pkg.spec]

            # Maps package path -> `pkg.tag` tags of the instance `pkg.ref`
            # points to, for every path where the ref exists.
            package_tags = {}
            for package_path in package_paths:
                try:
                    package_data = self.m.cipd.describe(package_path, pkg.ref)
                except self.m.step.StepFailure:
                    # If we got here this package doesn't have the correct ref.
                    # This is likely because it's a new platform for an existing
                    # package. In that case ignore this platform when checking
                    # that refs agree on package versions. We still need at
                    # least one platform to have the ref or the checks below
                    # will fail.
                    continue
                package_tags[package_path] = {
                    x.tag
                    for x in package_data.tags
                    if x.tag.startswith(pkg.tag + ':')
                }

            # Tags shared by every package that has the ref. This is an empty
            # set (never None) when no package has the ref, so the summary and
            # the error below work in that case too.
            tags = (
                set.intersection(*package_tags.values())
                if package_tags
                else set()
            )

            if not tags and pkg.allow_mismatched_refs:
                # The package with the requested ref has non-overlapping tag
                # values for different platforms.  Try relaxing the requirement
                # that all packages come from the same ref, and see if this
                # allows us to find a set with shared tag values.
                tags = self.find_shared_tags(package_tags, pkg.tag)

            with self.m.step.nest('common tags') as presentation:
                presentation.step_summary_text = '\n'.join(sorted(tags))

            if not tags:
                err_lines = [
                    f'no common tags across "{pkg.ref}" refs of packages'
                ]
                for package_path, path_tags in sorted(package_tags.items()):
                    err_lines.append('')
                    err_lines.append(package_path)
                    # Sorted so the failure message is deterministic.
                    err_lines.extend(sorted(path_tags))

                raise self.m.step.StepFailure('<br>'.join(err_lines))

            # Deterministically pick one of the common tags.
            new_version = min(tags)

            # Every tag starts with `pkg.tag + ':'`, so the split always yields
            # a value part.
            version_part = new_version.split(':', 1)[1]
            match = re.search(
                r'(?:\d|\b)(rc|pre|beta|alpha)(?:\d|\b)',
                version_part,
            )
            if match:
                raise self.m.step.StepFailure(
                    f'found pre-release indicator {match.group(1)!r} in '
                    f'{version_part!r}'
                )

            # Verify there's only one instance of each platform package with
            # this tag. NOTE(review): this relies on cipd.describe failing
            # when the tag is missing or ambiguous; confirm in the cipd module.
            with self.m.step.nest('check number of instances'):
                for package_path in package_paths:
                    self.m.cipd.describe(package_path, new_version)

            with self.m.step.nest('new_version') as presentation:
                presentation.step_summary_text = new_version

            if old_version in tags:
                with self.m.step.nest('already up-to-date') as presentation:
                    presentation.step_summary_text = (
                        f'current version {old_version} in common tags'
                    )
                return []

            # Only touch the parsed JSON when it is actually going to be
            # written back.
            package['tags'] = [new_version]
            self.m.file.write_text(
                f'write {basename}',
                json_path,
                json.dumps(
                    cipd_json,
                    indent=2,
                    separators=(',', ': '),
                ) + '\n',
            )
            return [
                Roll(
                    package_name=pkg.name,
                    old_version=old_version,
                    new_version=new_version,
                ),
            ]
