# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utility functions for rolling CIPD package versions."""

from __future__ import annotations

import collections
import dataclasses
import json
import re
from typing import TYPE_CHECKING

from PB.recipe_modules.fuchsia.cipd_util.upload_manifest import (
    CIPDUploadManifest,
)

from recipe_engine import recipe_api

if TYPE_CHECKING:  # pragma: no cover
    from typing import Sequence
    from recipe_engine import config_types
    from RECIPE_MODULES.pigweed.checkout import api as checkout_api


@dataclasses.dataclass(order=True)
class Roll:
    """A single package version change.

    Instances compare and sort field by field: package_name first, then
    old_version, then new_version.
    """

    package_name: str  # Display name of the rolled package.
    old_version: str  # Version tag before the roll.
    new_version: str  # Version tag after the roll.

    def message(self) -> str:
        """Return the two-line 'From <old>' / 'To <new>' description."""
        lines = (f'From {self.old_version}', f'To {self.new_version}')
        return '\n'.join(lines)


@dataclasses.dataclass
class Commit:
    """A set of Rolls to be landed together in a single commit."""

    # The package rolls included in this commit.
    rolls: list[Roll] = dataclasses.field(default_factory=list)

    def message(self, name: str | None = None) -> str:
        """Build the commit message for the collected rolls.

        Args:
            name: Name used in the subject line. When empty or None, the
                package names of the sorted rolls joined with ", " are used.

        Returns:
            'roll: <name>'. With exactly one roll, this is followed by a
            blank line and that roll's From/To lines. Otherwise, each roll
            adds a blank line, its package name, and its From/To lines, and
            a trailing blank line is appended.
        """
        rolls = sorted(self.rolls)

        if not name:
            name = ", ".join(x.package_name for x in rolls)

        result = [f'roll: {name}']

        if len(rolls) == 1:
            result.extend(('', rolls[0].message()))
        else:
            for roll in rolls:
                result.extend(('', roll.package_name, roll.message()))
            result.append('')

        return '\n'.join(result)

    def __bool__(self) -> bool:
        """A Commit is truthy only when it contains at least one roll."""
        return bool(self.rolls)


class CipdRollApi(recipe_api.RecipeApi):
    """Utility functions for rolling CIPD package versions."""

    Roll = Roll
    Commit = Commit

    def is_platform(self, part: str) -> bool:
        """Return true for platform-style strings.

        Example matches: "linux-amd64", "${platform}", "${os}-amd64", "cp38".
        Example non-matches: "clang", "amd64", "linux".

        Args:
            part: A single '/'-separated component of a CIPD package path.

        Returns:
            True if `part` contains '{' (a template such as "${platform}"),
            starts with a Python version tag like "cp38", or is exactly two
            '-'-separated pieces whose first piece is linux, mac or windows.
        """
        # Any templated component is treated as a platform placeholder.
        if '{' in part:
            return True

        # Match Python version indicators (prefix match, e.g. "cp38").
        if re.match(r'cp\d+', part):
            return True

        try:
            os_name, _ = part.split('-')
        except ValueError:
            # Not exactly one '-', so not an "<os>-<arch>" pair.
            return False
        return os_name in ('linux', 'mac', 'windows')

    def find_shared_tags(
        self,
        package_tags: dict[str, set[str]],
        tag: str,
    ) -> set[str]:
        """Attempts to find a tag shared by all packages.

        This function can be used if the intersection of the sets of tags
        associated with different-platform packages with the same 'ref' is
        empty. It finds a tag shared by all packages, with as many of them as
        possible matching 'ref'.

        Candidate tags are tried from most to least common. For each package
        that lacks the candidate, the version carrying that tag is looked up
        with cipd describe.

        Args:
            package_tags: Map from package path to its set of tags. On
                success, entries for packages whose version had to change
                are replaced with that version's tags (mutated in place).
            tag: Tag key; only tags starting with '<tag>:' are collected from
                newly described versions.

        Returns:
            The tags shared by all packages for the first candidate that
            works, or an empty set if no candidate works.
        """
        # Find the most common tags. Both the package paths and each
        # package's tags are sorted. Counter.most_common() breaks ties by
        # first-insertion order, so feeding it unordered sets would make the
        # candidate order depend on set iteration order.
        package_paths = sorted(package_tags)
        counter = collections.Counter()
        for path in package_paths:
            counter.update(sorted(package_tags[path]))
        most_common_tags = counter.most_common()

        with self.m.step.nest("find shared tag"):
            for tag_candidate, _ in most_common_tags:
                # There is at least one package for which the version with the
                # specified 'ref' does not have this tag. See if there exists a
                # version of this package that *does* have this tag.  If so, use
                # that version.
                updated_tags = {}
                for package_path in package_paths:
                    if tag_candidate in package_tags[package_path]:
                        # For this package we already have a version with this
                        # tag, nothing to do.
                        continue
                    try:
                        package_data = self.m.cipd.describe(
                            package_path, tag_candidate
                        )
                    except self.m.step.StepFailure:
                        # No luck: there exists no version with this tag.
                        break
                    updated_tags[package_path] = {
                        x.tag
                        for x in package_data.tags
                        if x.tag.startswith(tag + ':')
                    }

                else:
                    # We found a version of each package with the tag_candidate.
                    merged_tags = {}
                    merged_tags.update(package_tags)
                    merged_tags.update(updated_tags)
                    tags = set.intersection(*merged_tags.values())
                    # Should always succeed: every set contains tag_candidate.
                    assert len(tags) > 0
                    # Update package_tags to be consistent with the returned
                    # tags.
                    package_tags.update(updated_tags)
                    return tags

        # We failed to find any tag that meets our criteria.
        return set()

    def update_package(
        self,
        checkout_root: config_types.Path,
        pkg,
    ) -> Roll | None:
        """Update one package's version in its CIPD JSON file.

        The version resolved for pkg.ref is looked up on every platform
        variant of the package. A '<pkg.tag>:' tag common to all of them is
        chosen. If it differs from the current version, the JSON file is
        rewritten.

        Args:
            checkout_root: Checkout root; pkg.json_path is resolved relative
                to it, with '/' or '\\' accepted as separators.
            pkg: Package settings object. The attributes read are json_path,
                name, spec, ref, tag and allow_mismatched_refs. If pkg.name
                is empty, it is filled in from pkg.spec (mutating pkg).

        Returns:
            A Roll describing the change if the JSON file was rewritten, or
            None if the current version is already one of the common tags.

        Raises:
            StepFailure: If pkg.spec is not in the JSON file, if no tag is
                shared by all platform packages, or if the chosen version
                looks like a pre-release.
            AssertionError: If the matching JSON entry has no 'platforms'.
        """
        json_path = checkout_root.joinpath(*re.split(r'[\\/]+', pkg.json_path))

        if not pkg.name:
            # Turn foo/bar/baz/${platform} and foo/bar/baz/${os=mac}-${arch}
            # into 'baz'.
            pkg.name = [
                part
                for part in pkg.spec.split('/')
                if not self.is_platform(part)
            ][-1]

        basename = self.m.path.basename(json_path)
        cipd_json = self.m.file.read_json(f'read {basename}', json_path)
        # The file is either a bare list of packages or a dict holding the
        # list under 'packages'.
        packages = cipd_json
        if isinstance(cipd_json, dict):
            packages = cipd_json['packages']

        old_version = None
        package = None
        for package in packages:
            if package['path'] == pkg.spec:
                old_version = package['tags'][0]
                break
        else:
            raise self.m.step.StepFailure(
                f"couldn't find package {pkg.spec} in {json_path}"
            )

        assert package.get('platforms'), 'platforms empty in json'
        platforms = package.get('platforms')
        # rpartition (unlike unpacking rsplit('/', 1)) does not crash when the
        # spec contains no '/'; with a '/' present the result is identical.
        base, _, last_part = pkg.spec.rstrip('/').rpartition('/')
        if self.is_platform(last_part):
            package_paths = [f'{base}/{x}' for x in platforms]
        else:
            package_paths = [pkg.spec]

        package_tags = {}
        tags = None
        for package_path in package_paths:
            try:
                package_data = self.m.cipd.describe(package_path, pkg.ref)
            except self.m.step.StepFailure:
                # If we got here this package doesn't have the correct ref. This
                # is likely because it's a new platform for an existing package.
                # In that case ignore this platform when checking that refs
                # agree on package versions. We still need at least one platform
                # to have the ref or the checks below will fail.
                continue

            package_tags[package_path] = {
                x.tag
                for x in package_data.tags
                if x.tag.startswith(pkg.tag + ':')
            }
            if tags is None:
                tags = set(package_tags[package_path])
            else:
                tags.intersection_update(package_tags[package_path])

        # No platform had the ref. Use an empty set so the summary below works
        # and the descriptive StepFailure is raised instead of a TypeError
        # from sorted(None).
        if tags is None:
            tags = set()

        if not tags and pkg.allow_mismatched_refs:
            # The package with the requested ref has non-overlapping tag values
            # for different platforms.  Try relaxing the requirement that all
            # packages come from the same ref, and see if this allows us to find
            # a set with shared tag values.
            tags = self.find_shared_tags(package_tags, pkg.tag)

        with self.m.step.nest('common tags') as presentation:
            presentation.step_summary_text = '\n'.join(sorted(tags))

        if not tags:
            err_lines = [f'no common tags across "{pkg.ref}" refs of packages']
            for package_path, path_tags in sorted(package_tags.items()):
                err_lines.append('')
                err_lines.append(package_path)
                # Sorted so the failure message is deterministic.
                err_lines.extend(sorted(path_tags))

            raise self.m.step.StepFailure('<br>'.join(err_lines))

        # Deterministically pick one of the common tags.
        new_version = min(tags)
        package['tags'] = [new_version]

        version_part = new_version.split(':', 1)[1]
        match = re.search(
            r'(?:\d|\b)(rc|pre|beta|alpha)(?:\d|\b)',
            version_part,
        )
        if match:
            raise self.m.step.StepFailure(
                f'found pre-release indicator {match.group(1)!r} in '
                f'{version_part!r}'
            )

        # Verify there's only one instance of each platform package with this
        # tag.
        with self.m.step.nest('check number of instances'):
            for package_path in package_paths:
                self.m.cipd.describe(package_path, new_version)

        with self.m.step.nest('new_version') as presentation:
            presentation.step_summary_text = new_version

        if old_version in tags:
            with self.m.step.nest('already up-to-date') as presentation:
                presentation.step_summary_text = (
                    f'current version {old_version} in common tags'
                )
            return None

        self.m.file.write_text(
            f'write {basename}',
            json_path,
            json.dumps(cipd_json, indent=2, separators=(',', ': ')) + '\n',
        )
        return Roll(
            package_name=pkg.name,
            old_version=old_version,
            new_version=new_version,
        )
