# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utility functions for rollers."""

from __future__ import annotations

import collections
import enum
import re
from typing import TYPE_CHECKING
import urllib
import urllib.parse

import attrs

from PB.recipe_modules.fuchsia.auto_roller.options import (
    Options as AutoRollerOptions,
)
from recipe_engine import recipe_api

if TYPE_CHECKING:  # pragma: no cover
    from typing import Any, Optional, Sequence
    from recipe_engine import config_types

# If we're embedding the original commit message, prepend 'Original-' to lines
# which begin with these tags. Each entry is either a literal prefix (str) or
# a compiled pattern applied with .match(), i.e. anchored at the line start.
ESCAPE_TAGS: tuple[re.Pattern | str, ...] = (
    'Bug:',
    'Fixed:',
    'Fixes:',
    'Requires:',
    'Reviewed-on:',
)

# If we're embedding the original commit message, remove lines which begin
# with these tags (same matching rules as ESCAPE_TAGS above). Plain strings
# are case-sensitive prefixes; the compiled patterns ignore case.
FILTER_TAGS: tuple[re.Pattern | str, ...] = (
    'API-Review:',
    'Acked-by:',
    # Any "<word>Auto-Submit:" or "<word>-Auto-Submit:" footer.
    re.compile(r'^\w+-?Auto-Submit:', re.IGNORECASE),
    'Build-Errors:',
    'CC:',
    'CQ-Do-Not-Cancel-Tryjobs:',
    'Cq-Include-Trybots:',
    'Change-Id:',
    'Commit-Queue:',
    'Cq-Cl-Tag:',
    # "Git watcher:", "Git-watcher:" or "Gitwatcher:", any case.
    re.compile(r'Git[ -]?watcher:', re.IGNORECASE),
    'Lint:',
    'No-Docs-Update-Reason:',
    'No-Presubmit:',
    # Only lines starting with these exact values (including ": true") are
    # filtered; other values of these footers are kept.
    'No-Tree-Checks: true',
    'No-Try: true',
    'Presubmit-Verified:',
    # Any "<word>Readability-Trivial:" or "<word>-Readability-Trivial:".
    re.compile(r'^\w+-?Readability-Trivial:', re.IGNORECASE),
    'Reviewed-by:',
    'Roller-URL:',
    'Signed-off-by:',
    'Testability-Review:',
    'Tested-by:',
)


def _match_tag(line: str, tag: re.Pattern | str) -> bool:
    if hasattr(tag, 'match'):
        return bool(tag.match(line))
    return line.startswith(tag)


def _sanitize_message(message: str) -> str:
    """Sanitize lines of a commit message before embedding it in a roll.

    Lines which begin with any of FILTER_TAGS are dropped. Remaining lines
    which begin with any of ESCAPE_TAGS get an 'Original-' prefix
    (presumably so they are not treated as footers of the roll commit).

    If the first line is longer than 80 characters, a copy truncated to 50
    characters plus a blank line is prepended. The result then starts with a
    short subject, and the full original first line is kept below it.

    Args:
        message: The original commit message. May be empty.

    Returns:
        The sanitized lines joined with newlines, or '' for an empty message.
    """
    lines = message.splitlines()

    # Empty commit messages occur in practice; guard before indexing so they
    # don't raise IndexError.
    if lines and len(lines[0]) > 80:
        lines = [lines[0][:50], ''] + lines

    sanitized = []
    for line in lines:
        if any(_match_tag(line, tag) for tag in FILTER_TAGS):
            continue
        if any(_match_tag(line, tag) for tag in ESCAPE_TAGS):
            line = 'Original-' + line
        sanitized.append(line)
    return '\n'.join(sanitized)


class Direction(enum.Enum):
    """Relationship between a pinned revision and a candidate new revision.

    Values are produced by RollUtilApi.get_roll_direction from git ancestry
    checks.
    """

    # old and new are the same commit; there is nothing to roll.
    CURRENT = 'CURRENT'
    # old is an ancestor of new.
    FORWARD = 'FORWARD'
    # new is an ancestor of old.
    BACKWARD = 'BACKWARD'
    # Neither is an ancestor of the other (history was rewritten).
    REBASE = 'REBASE'


@attrs.define(frozen=True)
class Account:
    """A person or robot identified by display name and email address.

    Instances are immutable and hashable (so they can be collected in sets).
    They sort by email first, then by name.
    """

    name: str
    email: str

    def __lt__(self, other: Account) -> bool:
        # Follow the rich-comparison protocol: let Python raise TypeError for
        # unrelated types instead of failing with AttributeError here.
        if not isinstance(other, Account):
            return NotImplemented
        return (self.email, self.name) < (other.email, other.name)


@attrs.define
class Commit:
    """One commit included in a roll.

    Attributes:
        hash: Full commit hash as reported by ``git log`` (``%H``).
        message: Full commit message (``%B``).
        author: Git author of the commit.
        owner: Owner of the matching Gerrit change, or None when it was not
            looked up or could not be found.
        reviewers: Accounts listed as reviewers on the Gerrit change; empty
            when no change details were fetched.
    """

    hash: str
    message: str
    author: Account
    owner: Account | None
    reviewers: tuple[Account, ...]


@attrs.define
class Roll:
    """One project's roll from ``old_revision`` to ``new_revision``.

    Constructing a Roll runs recipe steps. It resolves the project's remote
    URL (``remote``) and collects the rolled commits (``commits``) from
    ``git log`` in ``proj_dir``. For up to the first 10 commits it also
    queries Gerrit for owner and reviewer details. Instances are normally
    created through ``RollUtilApi.create_roll``, which supplies ``api``.

    Attributes:
        project_name: Name of the rolled project; used in step names and in
            the roll message.
        old_revision: Revision currently pinned.
        new_revision: Revision being rolled to.
        proj_dir: Directory of the project checkout; git commands run there.
        direction: Roll direction; Direction.CURRENT is rejected.
        commits: Rolled commits, filled in during construction.
        remote: https URL of the project's remote, filled in during
            construction.
    """

    _api: recipe_api.RecipeApi
    project_name: str
    old_revision: str
    new_revision: str
    proj_dir: str
    direction: Direction = attrs.field()
    commits: tuple[Commit, ...] | None = None
    remote: str | None = None
    # When False, commit collection steps are not grouped under a step named
    # after the project.
    _nest_steps: bool = True

    @direction.validator
    def check(self, _, value: Direction) -> None:  # pragma: no cover
        """attrs validator: reject unknown and no-op directions."""
        if value not in Direction:
            raise ValueError(f'invalid direction: {value}')
        if value == Direction.CURRENT:
            raise ValueError('attempt to do a no-op roll')

    def __attrs_post_init__(self) -> None:
        # Runs after the attrs-generated __init__, so all fields are set.
        self._set_remote()
        with self._api.context(cwd=self.proj_dir):
            if self._nest_steps:
                with self._api.step.nest(self.project_name):
                    self._set_commits()
            else:
                self._set_commits()  # pragma: no cover

    def _set_commits(self) -> None:
        """Populate ``self.commits`` from ``git log`` plus Gerrit details.

        For a forward roll from a full hash, every commit in
        ``old_revision..new_revision`` is listed. Otherwise only the last 5
        commits reachable from ``new_revision`` are listed.
        """
        log_cmd: list[str] = [
            'log',
            '--pretty=format:%H\n%an\n%ae\n%B',
            # Separate entries with null bytes since most entries
            # will contain newlines ("%B" is the full commit
            # message, not just the first line.)
            '-z',
        ]

        if _is_hash(self.old_revision) and self.direction == Direction.FORWARD:
            log_cmd.append(f'{self.old_revision}..{self.new_revision}')
        else:
            log_cmd.extend(('--max-count', '5', self.new_revision))

        log_kwargs: dict[str, Any] = {'stdout': self._api.raw_io.output_text()}

        commit_log: list[str] = (
            self._api.git('git log', *log_cmd, **log_kwargs)
            .stdout.strip('\0')
            .split('\0')
        )

        # Every commit in this roll lives on the same Gerrit host.
        full_host = f'{self.gerrit_name}-review.googlesource.com'

        commits: list[Commit] = []
        for i, commit in enumerate(commit_log):
            commit_hash, name, email, message = commit.split('\n', 3)
            author = Account(name, email)
            owner: Account | None = None
            reviewers: list[Account] = []

            changes = []

            # If there are a lot of CLs in this roll only get owner and
            # reviewer data from the first 10 so we don't make too many
            # requests of Gerrit.
            if i < 10:
                change_query_step = self._api.gerrit.change_query(
                    'get change-id',
                    f'commit:{commit_hash}',
                    host=full_host,
                    test_data=self._api.json.test_api.output(
                        [{'_number': 12345}]
                    ),
                    ok_ret='any',
                )
                if change_query_step.exc_result.retcode == 0:
                    changes = change_query_step.json.output

            # Only trust the lookup when it identifies exactly one change.
            if changes and len(changes) == 1:
                number = changes[0]['_number']
                step = self._api.gerrit.change_details(
                    f'get {number}',
                    number,
                    host=full_host,
                    test_data=self._api.json.test_api.output(
                        {
                            'owner': {
                                'name': 'author',
                                'email': 'author@example.com',
                            },
                            'reviewers': {
                                'REVIEWER': [
                                    {
                                        'name': 'reviewer',
                                        'email': 'reviewer@example.com',
                                    },
                                    {
                                        'name': 'nobody',
                                        'email': 'nobody@google.com',
                                    },
                                    {
                                        'name': 'robot',
                                        'email': 'robot@gserviceaccount.com',
                                    },
                                ],
                            },
                        }
                    ),
                    ok_ret='any',
                )

                if step.exc_result.retcode == 0:
                    details = step.json.output
                    owner = Account(
                        details['owner']['name'], details['owner']['email']
                    )
                    # Gerrit only includes 'reviewers' when detailed labels
                    # are requested, and the map may lack a 'REVIEWER'
                    # state, so treat either as "no reviewers".
                    reviewer_map = details.get('reviewers') or {}
                    for reviewer in reviewer_map.get('REVIEWER', ()):
                        reviewers.append(
                            Account(
                                reviewer['name'],
                                reviewer.get('email', 'robot@example.com'),
                            ),
                        )

            commits.append(
                Commit(
                    hash=commit_hash,
                    author=author,
                    owner=owner,
                    reviewers=tuple(reviewers),
                    message=message,
                )
            )

        self.commits = tuple(commits)

    def _set_remote(self) -> None:
        """Set ``self.remote`` to the https URL of the project's remote."""
        api = self._api

        with api.step.nest('remote'), api.context(cwd=self.proj_dir):
            # There may be multiple remote names. Only get the first one. They
            # should refer to the same URL so it doesn't matter which we use.
            name = (
                api.git(
                    'name',
                    'remote',
                    stdout=api.raw_io.output_text(),
                    step_test_data=lambda: api.raw_io.test_api.stream_output_text(
                        'origin'
                    ),
                )
                .stdout.strip()
                .split('\n')[0]
            )

            remote = api.git(
                'url',
                'remote',
                'get-url',
                name,
                stdout=api.raw_io.output_text(),
                step_test_data=lambda: api.raw_io.test_api.stream_output_text(
                    'sso://pigweed/pigweed/pigweed'
                ),
            ).stdout.strip()

            self.remote = api.sso.sso_to_https(remote)

    @property
    def gerrit_name(self) -> str:
        """First label of the remote's hostname.

        For example, 'pigweed' for https://pigweed.googlesource.com/...
        """
        return urllib.parse.urlparse(self.remote).netloc.split('.')[0]


@attrs.define
class Message:
    """A roll commit message kept as a format template plus its arguments.

    Attributes:
        name: Project name; used when combining several rolls' messages.
        template: ``str.format`` template for the message body.
        kwargs: Values substituted into ``template``.
        num_commits: Number of rolled commits, or the string 'multiple' when
            the exact count is unknown.
        footer: Lines appended after the body when rendered with a footer.
    """

    name: str
    template: str
    kwargs: dict[str, Any]
    num_commits: int | str
    footer: tuple[str, ...] = ()

    def render(self, with_footer: bool = True) -> str:
        """Format the template, optionally followed by the footer lines.

        Each footer entry is placed on its own line directly after the body.
        """
        parts = [self.template.format(**self.kwargs)]
        if with_footer:
            parts.extend(self.footer)
        return '\n'.join(parts)


def _is_hash(value: str) -> bool:
    """Return True if ``value`` is a full git object id.

    Accepts exactly 40 (SHA-1) or 64 (SHA-256) hex digits in either case.
    The whole string must match. Anything longer, or a hash with a suffix
    such as ``<hash>~1``, is rejected; the previous prefix-only match
    accepted those.
    """
    return re.fullmatch(r'[0-9a-fA-F]{40}(?:[0-9a-fA-F]{24})?', value) is not None


def _pprint_dict(d: dict) -> str:
    """Render ``d`` as one ``key: value`` line per entry, ordered by key.

    Keys and values are shown via ``repr()``. Every line, including the
    last, ends with a newline, so an empty dict renders as ''.
    """
    return ''.join(f'{key!r}: {val!r}\n' for key, val in sorted(d.items()))


class RollUtilApi(recipe_api.RecipeApi):
    """Helpers shared by roller recipes.

    Builds Roll objects, gathers their authors and reviewers, renders roll
    commit messages, decides roll direction, and normalizes remote URLs.
    """

    Account = Account
    Roll = Roll
    Direction = Direction

    def __init__(self, props, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Label name -> value, ordered by label name.
        self.labels_to_set = collections.OrderedDict()
        for label in sorted(props.labels_to_set, key=lambda x: x.label):
            self.labels_to_set[str(label.label)] = label.value
        self.labels_to_wait_on = sorted(str(x) for x in props.labels_to_wait_on)
        # Lines appended at the end of every roll message.
        self.footer = list(props.footer)
        # Optional text appended on its own line after the roll message.
        self._commit_divider = props.commit_divider

    def authors(self, *rolls: Roll) -> set[Account]:
        """Return the git authors and Gerrit owners of all rolled commits."""
        authors = set()
        for roll in rolls:
            for commit in roll.commits:
                if commit.author:
                    authors.add(commit.author)
                if commit.owner:
                    authors.add(commit.owner)
        return authors

    def fake_author(self, author: Account) -> Account:
        """Return a copy of ``author`` with an attribution-only email.

        ``user@domain`` becomes ``user@pigweed.infra.roller.domain``. The
        author can then be credited without the commit being attributed to
        their real Gerrit account. Emails that already carry the prefix,
        service-account emails, and strings without an '@' are returned
        unchanged.
        """
        email = author.email
        prefix = 'pigweed.infra.roller.'
        if (
            '@' in email
            and prefix not in email
            and not email.endswith('gserviceaccount.com')
        ):
            # Split on the last '@' so an unusual local part (git does not
            # validate author emails) can't break the unpacking.
            user, _, domain = email.rpartition('@')
            email = f'{user}@{prefix}{domain}'

        return Account(
            author.name,
            email,
        )

    def reviewers(self, *rolls: Roll) -> set[Account]:
        """Return every reviewer of every commit in ``rolls``."""
        reviewers = set()
        for roll in rolls:
            for commit in roll.commits:
                reviewers.update(commit.reviewers)
        return reviewers

    def can_cc_on_roll(self, email: str, host: str) -> bool:
        """Return True if Gerrit on ``host`` has an account for ``email``."""
        # Assume all queried accounts exist on Gerrit in testing except for
        # nobody@google.com.
        test_data = self.m.json.test_api.output([{'_account_id': 123}])
        if email == 'nobody@google.com':
            test_data = self.m.json.test_api.output([])

        return bool(
            self.m.gerrit.account_query(
                email,
                f'email:{email}',
                host=host,
                test_data=test_data,
            ).json.output
        )

    def include_cc(
        self,
        account: Account,
        cc_domains: Sequence[str],
        host: str,
    ) -> bool:
        """Decide whether ``account`` should be CC'd on the roll.

        The account is skipped in three cases: it is a service account; its
        domain is not in ``cc_domains`` (when ``cc_domains`` is non-empty);
        or it has no Gerrit account on ``host``. The decision is shown as
        the summary of a step named after the email.
        """
        with self.m.step.nest(f'cc {account.email}') as pres:
            # partition() tolerates an email without '@' (domain becomes '').
            domain = account.email.partition('@')[2]
            if domain.endswith('gserviceaccount.com'):
                pres.step_summary_text = 'not CCing, robot account'
                return False
            if cc_domains and domain not in cc_domains:
                pres.step_summary_text = 'not CCing, domain excluded'
                return False
            if not self.can_cc_on_roll(account.email, host=host):
                pres.step_summary_text = 'not CCing, no account in Gerrit'
                return False

            pres.step_summary_text = 'CCing'
            return True

    def _single_commit_roll_message(self, roll: Roll) -> Message:
        """Build the Message for a roll containing exactly one commit."""
        template = """
roll: {project_name}: {sanitized_message}

{remote}
{project_name} Rolled-Commits: {old_revision:.15}..{new_revision:.15}
        """.strip()

        commit = roll.commits[0]

        # Use the same placeholder as the multi-commit message for empty
        # commit messages instead of embedding nothing.
        if commit.message:
            sanitized_message = _sanitize_message(commit.message)
        else:
            sanitized_message = '(empty commit message)'

        kwargs = {
            'project_name': roll.project_name,
            'remote': roll.remote,
            'original_message': commit.message,
            'sanitized_message': sanitized_message,
            'old_revision': roll.old_revision,
            'new_revision': roll.new_revision,
        }

        message = Message(
            name=roll.project_name,
            template=template,
            kwargs=kwargs,
            num_commits=1,
            footer=tuple(self.footer),
        )

        with self.m.step.nest(f'message for {roll.project_name}') as pres:
            pres.logs['template'] = template
            pres.logs['kwargs'] = _pprint_dict(kwargs)
            pres.logs['message'] = message.render()

        return message

    def _multiple_commits_roll_message(self, roll: Roll) -> Message:
        """Build the Message for a roll containing several commits.

        Each commit contributes a one-line summary. When the old revision is
        not a full hash, the commit count is reported as 'multiple' and the
        list ends with '...'.
        """
        template = """
roll: {project_name} {num_commits} commits

{one_liners}

{remote}
{project_name} Rolled-Commits: {old_revision:.15}..{new_revision:.15}
    """.strip()

        one_liners = []
        for commit in roll.commits:
            # Handle case where the commit message is empty. Example:
            # https://github.com/google/googletest/commit/148ab827cacc7a879832f40313bda87a65b1e8a3
            first_line = '(empty commit message)'
            if commit.message:
                first_line = commit.message.splitlines()[0]
            one_liners.append(f'{commit.hash:.15} {first_line[0:50]}')

        num_commits = len(roll.commits)

        if not _is_hash(roll.old_revision):
            num_commits = 'multiple'
            one_liners.append('...')

        # Keep very long rolls readable: first 100, '...', last 100.
        if len(one_liners) > 500:
            one_liners = one_liners[0:100] + ['...'] + one_liners[-100:]
            # In case both this and the previous condition match.
            if one_liners[-1] == '...':
                one_liners.pop()  # pragma: no cover

        kwargs = {
            'project_name': roll.project_name,
            'remote': roll.remote,
            'num_commits': num_commits,
            'one_liners': '\n'.join(one_liners),
            'old_revision': roll.old_revision,
            'new_revision': roll.new_revision,
        }

        message = Message(
            name=roll.project_name,
            template=template,
            kwargs=kwargs,
            num_commits=num_commits,
            footer=tuple(self.footer),
        )

        with self.m.step.nest('message') as pres:
            pres.logs['template'] = template
            pres.logs['kwargs'] = _pprint_dict(kwargs)
            pres.logs['message'] = message.render()

        return message

    def _single_roll_message(self, roll: Roll) -> Message:
        """Build the Message for one roll, choosing the template by size."""
        if len(roll.commits) > 1:
            return self._multiple_commits_roll_message(roll)
        return self._single_commit_roll_message(roll)

    def _multiple_rolls_message(self, *rolls: Roll) -> str:
        """Combine the messages of several rolls into one commit message.

        Rolls are ordered by project name. The subject lists every project
        and the total commit count, or 'multiple' if any roll's count is
        unknown. The footer is emitted once at the end, if non-empty.
        """
        rolls = sorted(rolls, key=lambda x: x.project_name)

        messages = [self._single_roll_message(roll) for roll in rolls]

        # num_commits is the string 'multiple' for rolls whose old revision
        # is not a hash, so a plain sum() would raise TypeError.
        counts = [x.num_commits for x in messages]
        if all(isinstance(count, int) for count in counts):
            total = sum(counts)
        else:
            total = 'multiple'

        texts = [
            'roll: {}: Roll {} commits'.format(
                ', '.join(x.name for x in messages),
                total,
            )
        ]
        texts.extend(x.render(with_footer=False) for x in messages)
        # An empty footer would leave trailing blank lines in the message.
        if self.footer:
            texts.append('\n'.join(f'{x}' for x in self.footer))

        return '\n\n'.join(texts)

    def create_roll(self, **kwargs) -> Roll:
        """Create a Roll. See Roll class above for details."""
        return Roll(api=self.m, **kwargs)

    def message(self, *rolls: Roll) -> str:
        """Return the commit message for a roll CL covering ``rolls``.

        The configured commit divider, if any, is appended on its own line.
        """
        with self.m.step.nest('roll message'):
            if len(rolls) > 1:
                result = self._multiple_rolls_message(*rolls)
            else:
                result = self._single_roll_message(*rolls).render()
            if self._commit_divider:
                result += f'\n{self._commit_divider}'
            return result

    def get_roll_direction(
        self,
        git_dir: config_types.Path,
        old: str,
        new: str,
        name: str = 'get roll direction',
    ) -> Direction:
        """Return the Direction of a roll from ``old`` to ``new``.

        Runs ``git merge-base --is-ancestor`` both ways inside ``git_dir``
        and records the outcome as the step summary.

        Returns:
            CURRENT if old == new (no git commands are run) or each is an
            ancestor of the other. FORWARD if only old is an ancestor of
            new. BACKWARD if only new is an ancestor of old. Otherwise
            REBASE.
        """
        if old == new:
            with self.m.step.nest(name) as pres:
                pres.step_summary_text = 'up-to-date'
            return Direction.CURRENT

        with self.m.context(cwd=git_dir):
            with self.m.step.nest(name) as pres:
                forward = self.m.git(
                    'is forward',
                    'merge-base',
                    '--is-ancestor',
                    old,
                    new,
                    ok_ret=(0, 1),
                )

                backward = self.m.git(
                    'is backward',
                    'merge-base',
                    '--is-ancestor',
                    new,
                    old,
                    ok_ret=(0, 1),
                )

                # Return code 0 means the first commit is an ancestor of the
                # second; 1 means it is not.
                is_forward = forward.exc_result.retcode == 0
                is_backward = backward.exc_result.retcode == 0

                if is_forward and not is_backward:
                    pres.step_summary_text = 'forward'
                    return Direction.FORWARD

                if is_backward and not is_forward:
                    pres.step_summary_text = 'backward'
                    return Direction.BACKWARD

                # If new and old are ancestors of each other then this is the
                # same commit. We should only hit this during testing because
                # the comparison at the top of the function should have caught
                # this situation. Report on the step already open rather than
                # nesting a duplicate step with the same name.
                if is_forward and is_backward:
                    pres.step_summary_text = 'up-to-date'
                    return Direction.CURRENT

                # If old is not an ancestor of new and new is not an ancestor
                # of old then history was rewritten in some manner but we still
                # need to update the pin.
                pres.step_summary_text = 'rebase'
                return Direction.REBASE

    def can_roll(self, direction: Direction) -> bool:
        """Return True for directions that should update the pin."""
        return direction in (Direction.FORWARD, Direction.REBASE)

    def skip_roll_step(
        self, remote: str, old_revision: str, new_revision: str
    ) -> None:
        """Add a step explaining why the roll was skipped, with links."""
        with self.m.step.nest('cancelling roll') as pres:
            fmt = (
                'not updating from {old} to {new} because {old} is newer '
                'than {new}'
            )
            if old_revision == new_revision:
                fmt = (
                    'not updating from {old} to {new} because they are '
                    'identical'
                )
            pres.step_summary_text = fmt.format(
                old=old_revision[0:7], new=new_revision[0:7]
            )
            pres.links[old_revision] = f'{remote}/+/{old_revision}'
            pres.links[new_revision] = f'{remote}/+/{new_revision}'

    def normalize_remote(self, remote: str, base: str) -> str:
        """Convert relative paths to absolute paths.

        Support relative paths. If the top-level project is
        "https://pigweed.googlesource.com/ex/ample" then a submodule path of
        "./abc" maps to "https://pigweed.googlesource.com/ex/ample/abc" and
        "../abc" maps to "https://pigweed.googlesource.com/ex/abc". Minimal
        error-checking because git does most of these checks for us.

        Also converts sso to https.

        Args:
            remote (str): Submodule remote URL.
            base (str): Fully-qualified superproject remote URL.

        Returns:
            The absolute https remote URL.
        """
        if remote.startswith('.'):
            remote = '/'.join((base.rstrip('/'), remote.lstrip('/')))

            # Collapse '/./' and '<segment>/../' until nothing changes.
            changes = 1
            while changes:
                changes = 0

                remote, n = re.subn(r'/\./', '/', remote)
                changes += n

                remote, n = re.subn(r'/[^/]+/\.\./', '/', remote)
                changes += n

        return self.m.sso.sso_to_https(remote)

    def merge_auto_roller_overrides(
        self,
        auto_roller_options: AutoRollerOptions,
        override_auto_roller_options: AutoRollerOptions,
    ) -> AutoRollerOptions:
        """Return a new Options with the overrides merged on top.

        Neither argument is modified. Fields are combined with protobuf
        ``MergeFrom`` semantics.
        """
        result = AutoRollerOptions()
        result.CopyFrom(auto_roller_options)
        result.MergeFrom(override_auto_roller_options)
        return result
