# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Utility functions for rollers."""

import collections
import re

import attr
import enum
from recipe_engine import recipe_api
from six.moves import urllib

# When a rolled commit's message is embedded in a roll message, lines which
# begin with these tags get an 'Original-' prefix (see _sanitize_message),
# presumably so tools that parse these trailers do not act on them again for
# the roll CL.
ESCAPE_TAGS = [
    'Bug:',
    'Fixed:',
    'Fixes:',
    'Requires:',
    'Reviewed-on:',
]

# When a rolled commit's message is embedded in a roll message, lines matching
# any of these tags are dropped. Plain strings are tested as line prefixes;
# compiled patterns are tested with .match() (see _match_tag).
FILTER_TAGS = [
    'API-Review:',
    'Acked-by:',
    # NOTE(review): unlike the other string entries this has no trailing ':',
    # so it drops any line that merely starts with 'Auto-Submit' — confirm
    # intended.
    'Auto-Submit',
    # A single word (no '-') followed by '-Auto-Submit:'.
    re.compile(r'^\w+-Auto-Submit:'),
    'CC:',
    'CQ-Do-Not-Cancel-Tryjobs:',
    'Cq-Include-Trybots:',
    'Change-Id:',
    'Commit-Queue:',
    'Cq-Cl-Tag:',
    'No-Docs-Update-Reason:',
    'No-Presubmit:',
    # The next two include the value in the prefix, so e.g. 'No-Try: false'
    # lines are kept.
    'No-Tree-Checks: true',
    'No-Try: true',
    'Reviewed-by:',
    'Roller-URL:',
    'Signed-off-by:',
    'Testability-Review:',
    'Tested-by:',
]


def _match_tag(line, tag):
    if hasattr(tag, 'match'):
        return tag.match(line)
    return line.startswith(tag)


def _sanitize_message(message):
    """Sanitize the lines of a commit message for embedding in a roll message.

    Lines matching any FILTER_TAGS entry are dropped. Remaining lines that
    match any ESCAPE_TAGS entry are prefixed with 'Original-'. Both lists are
    matched through _match_tag, so either may contain plain-string prefixes
    or compiled regexes.

    Args:
        message (str): The original commit message.

    Returns:
        str: The surviving lines joined with newlines (no trailing newline).
    """
    kept = []
    for line in message.splitlines():
        if any(_match_tag(line, tag) for tag in FILTER_TAGS):
            continue
        # Use _match_tag here too so a regex in ESCAPE_TAGS works instead of
        # making str.startswith raise TypeError.
        if any(_match_tag(line, tag) for tag in ESCAPE_TAGS):
            line = 'Original-' + line
        kept.append(line)
    return '\n'.join(kept)


class _Direction(enum.Enum):
    """How a candidate revision relates to the currently pinned revision."""

    # old and new are the same commit; nothing to roll.
    CURRENT = 'CURRENT'
    # old is an ancestor of new.
    FORWARD = 'FORWARD'
    # new is an ancestor of old.
    BACKWARD = 'BACKWARD'
    # Neither is an ancestor of the other (history was rewritten).
    REBASE = 'REBASE'


@attr.s
class Commit(object):
    """One commit included in a roll."""

    # Full commit hash (%H from git log).
    hash = attr.ib(type=str)
    # Full commit message (%B from git log).
    message = attr.ib(type=str)
    # Author email (%ae from git log).
    author = attr.ib(type=str)
    # Gerrit change owner email, or None if not looked up or not found.
    owner = attr.ib(type=str)
    # Tuple of Gerrit reviewer emails; empty if not looked up or not found.
    reviewers = attr.ib(type=tuple)


@attr.s
class Roll(object):
    """One project being rolled from old_revision to new_revision.

    Creating a Roll runs recipe steps: it reads the checkout's remote URL and
    collects the commits being rolled, looking up owner and reviewer data on
    Gerrit for at most the first 10 of them.
    """

    _api = attr.ib()
    project_name = attr.ib(type=str)
    old_revision = attr.ib(type=str)
    new_revision = attr.ib(type=str)
    # Path of the project's git checkout; git steps run with this as cwd.
    proj_dir = attr.ib(type=str)
    # A _Direction member other than CURRENT (enforced by check()).
    direction = attr.ib(type=_Direction)
    # Overwritten by __attrs_post_init__ with a tuple of Commit objects.
    commits = attr.ib(type=tuple, default=None)
    # Overwritten by __attrs_post_init__ with the checkout's https remote.
    remote = attr.ib(type=str, default=None)
    # Whether commit-collection steps are nested under a step named
    # project_name.
    _nest_steps = attr.ib(type=bool, default=True)

    @direction.validator
    def check(self, _, value):  # pragma: no cover
        """Validate direction.

        Raises:
            ValueError: If value is not a _Direction member, or is
                _Direction.CURRENT (a no-op roll).
        """
        # isinstance() rather than `value not in _Direction`: for non-members
        # that membership test raises TypeError on Python 3.8-3.11, and on
        # 3.12+ it is True for raw values such as 'FORWARD', which would then
        # fail every later `== _Direction.X` comparison.
        if not isinstance(value, _Direction):
            raise ValueError('invalid direction: {}'.format(value))
        if value == _Direction.CURRENT:
            raise ValueError('attempt to do a no-op roll')

    def __attrs_post_init__(self):
        self._set_remote()
        with self._api.context(cwd=self.proj_dir):
            if self._nest_steps:
                with self._api.step.nest(self.project_name):
                    self._set_commits()
            else:
                self._set_commits()  # pragma: no cover

    def _set_commits(self):
        """Set self.commits from `git log`, enriched with Gerrit data.

        For a FORWARD roll from a full hash every commit in
        old_revision..new_revision is listed; otherwise only the 5 most
        recent commits reachable from new_revision are.
        """
        log_cmd = [
            'log',
            '--pretty=format:%H\n%ae\n%B',
            # Separate entries with null bytes since most entries
            # will contain newlines ("%B" is the full commit
            # message, not just the first line.)
            '-z',
        ]

        if _is_hash(self.old_revision) and self.direction == _Direction.FORWARD:
            log_cmd.append(
                '{}..{}'.format(self.old_revision, self.new_revision)
            )
        else:
            log_cmd.extend(('--max-count', '5', self.new_revision))

        commit_log = (
            self._api.git(
                'git log', *log_cmd, stdout=self._api.raw_io.output_text()
            )
            .stdout.strip('\0')
            .split('\0')
        )

        # Every commit lives on the same Gerrit host; compute it once.
        full_host = '{}-review.googlesource.com'.format(self.gerrit_name)

        commits = []
        for i, commit in enumerate(commit_log):
            commit_hash, author, message = commit.split('\n', 2)
            owner = None
            reviewers = []

            # If there are a lot of CLs in this roll only get owner and
            # reviewer data from the first 10 so we don't make too many
            # requests of Gerrit.
            if i < 10:
                owner, reviewers = self._gerrit_owner_and_reviewers(
                    commit_hash, full_host
                )

            commits.append(
                Commit(
                    hash=commit_hash,
                    author=author,
                    owner=owner,
                    reviewers=tuple(reviewers),
                    message=message,
                )
            )

        self.commits = tuple(commits)

    def _gerrit_owner_and_reviewers(self, commit_hash, host):
        """Look up the owner and reviewers of commit_hash's change on host.

        Returns:
            (owner, reviewers): the owner's email and a list of reviewer
            emails, or (None, []) unless exactly one change matches and its
            details are fetched successfully.
        """
        changes = self._api.gerrit.change_query(
            'get change-id',
            'commit:{}'.format(commit_hash),
            host=host,
            test_data=self._api.json.test_api.output([{'_number': 12345}]),
        ).json.output

        # No match, or several matches, is ambiguous; skip the lookup.
        if not changes or len(changes) != 1:
            return None, []

        number = changes[0]['_number']
        step = self._api.gerrit.change_details(
            'get {}'.format(number),
            number,
            host=host,
            test_data=self._api.json.test_api.output(
                {
                    'owner': {'email': 'owner@example.com'},
                    'reviewers': {
                        'REVIEWER': [
                            {'email': 'reviewer@example.com'},
                            {'email': 'nobody@google.com'},
                            {'email': 'robot@gserviceaccount.com'},
                        ],
                    },
                }
            ),
            ok_ret='any',
        )

        # ok_ret='any' accepts a failed lookup; treat it as "no data".
        if step.exc_result.retcode != 0:
            return None, []

        details = step.json.output
        owner = details['owner']['email']
        reviewers = [r['email'] for r in details['reviewers']['REVIEWER']]
        return owner, reviewers

    def _set_remote(self):
        """Set self.remote to the https form of the checkout's remote URL."""
        api = self._api

        with api.step.nest('remote'), api.context(cwd=self.proj_dir):
            # NOTE(review): `git remote` prints one name per remote; with more
            # than one remote this passes a multi-line name to get-url.
            name = api.git(
                'name',
                'remote',
                stdout=api.raw_io.output_text(),
                step_test_data=lambda: api.raw_io.test_api.stream_output_text(
                    'origin'
                ),
            ).stdout.strip()

            remote = api.git(
                'url',
                'remote',
                'get-url',
                name,
                stdout=api.raw_io.output_text(),
                step_test_data=lambda: api.raw_io.test_api.stream_output_text(
                    'sso://pigweed/pigweed/pigweed'
                ),
            ).stdout.strip()

            self.remote = api.sso.sso_to_https(remote)

    @property
    def gerrit_name(self):
        """First label of the remote's hostname.

        For example, 'pigweed' for 'https://pigweed.googlesource.com/x/y'.
        """
        return urllib.parse.urlparse(self.remote).netloc.split('.')[0]


@attr.s
class Message(object):
    """A roll commit message template plus the values to render it with."""

    # Label for this message; the multi-roll header lists these names.
    name = attr.ib(type=str)
    # str.format template rendered with kwargs.
    template = attr.ib(type=str)
    kwargs = attr.ib(type=dict)
    # Number of commits rolled; callers in this module also store the string
    # 'multiple' when the exact count is unknown.
    num_commits = attr.ib(type=int)
    # Lines appended after the rendered template when with_footer is True.
    footer = attr.ib(type=tuple, default=())

    def render(self, with_footer=True):
        """Render the message.

        Args:
            with_footer (bool): Whether to append the footer lines.

        Returns:
            str: template formatted with kwargs, followed (if requested) by
            each footer entry on its own line.
        """
        result = [self.template.format(**self.kwargs)]
        if with_footer:
            result.extend(self.footer)
        return '\n'.join(result)


def _is_hash(value):
    return re.match(r'^[0-9a-fA-F]{40}', value)


def _pprint_dict(d):
    result = []
    for k, v in sorted(d.items()):
        result.append('{!r}: {!r}\n'.format(k, v))
    return ''.join(result)


class RollUtilApi(recipe_api.RecipeApi):
    """Helpers shared by roller recipes.

    Covers computing roll directions, creating Roll objects, choosing whom
    to CC, and writing roll commit messages.
    """

    def __init__(self, props, *args, **kwargs):
        super(RollUtilApi, self).__init__(*args, **kwargs)
        # Label name -> value from props.labels_to_set, in sorted name order.
        self.labels_to_set = collections.OrderedDict()
        for label in sorted(props.labels_to_set, key=lambda x: x.label):
            self.labels_to_set[str(label.label)] = label.value
        self.labels_to_wait_on = sorted(str(x) for x in props.labels_to_wait_on)
        # Extra lines appended to every generated roll message.
        self.footer = []

    def authors(self, *roll):
        """Return the set of author and owner emails across all rolls."""
        authors = set()
        for r in roll:
            for commit in r.commits:
                if commit.author:
                    authors.add(commit.author)
                if commit.owner:
                    authors.add(commit.owner)
        return authors

    def reviewers(self, *roll):
        """Return the set of reviewer emails across all rolls."""
        reviewers = set()
        for r in roll:
            for commit in r.commits:
                reviewers.update(commit.reviewers)
        return reviewers

    def can_cc_on_roll(self, email, host):
        """Return the Gerrit account query result for email on host.

        Callers treat a truthy (non-empty) result as "the account exists".
        """
        # Assume all queried accounts exist on Gerrit in testing except for
        # nobody@google.com.
        test_data = self.m.json.test_api.output([{'_account_id': 123}])
        if email == 'nobody@google.com':
            test_data = self.m.json.test_api.output([])

        return self.m.gerrit.account_query(
            email, 'email:{}'.format(email), host=host, test_data=test_data,
        ).json.output

    def include_cc(self, email, cc_domains, host):
        """Decide whether to CC email on a roll, recording the reason.

        Args:
            email (str): Candidate email address.
            cc_domains: If truthy, only emails in these domains are CCed.
            host (str): Gerrit host on which the account must exist.

        Returns:
            bool: True if email should be CCed.
        """
        with self.m.step.nest('cc {}'.format(email)) as pres:
            # partition() rather than split('@', 1)[1]: git author emails are
            # not validated, and an address without '@' must not crash.
            _, at_sign, domain = email.partition('@')
            if not at_sign:
                pres.step_summary_text = 'not CCing, invalid email'
                return False
            if domain.endswith('gserviceaccount.com'):
                pres.step_summary_text = 'not CCing, robot account'
                return False
            if cc_domains and domain not in cc_domains:
                pres.step_summary_text = 'not CCing, domain excluded'
                return False
            if not self.can_cc_on_roll(email, host=host):
                pres.step_summary_text = 'not CCing, no account in Gerrit'
                return False

            pres.step_summary_text = 'CCing'
            return True

    def _single_commit_roll_message(self, roll):
        """Build the Message for a roll containing exactly one commit."""
        template = """
[roll {project_name}] {sanitized_message}

{remote}
{project_name} Rolled-Commits: {old_revision:.15}..{new_revision:.15}
        """.strip()

        commit = roll.commits[0]

        kwargs = {
            'project_name': roll.project_name,
            'remote': roll.remote,
            'original_message': commit.message,
            'sanitized_message': _sanitize_message(commit.message),
            'old_revision': roll.old_revision,
            'new_revision': roll.new_revision,
        }

        message = Message(
            name=roll.project_name,
            template=template,
            kwargs=kwargs,
            num_commits=1,
            footer=tuple(self.footer),
        )

        with self.m.step.nest(
            'message for {}'.format(roll.project_name)
        ) as pres:
            pres.logs['template'] = template
            pres.logs['kwargs'] = _pprint_dict(kwargs)
            pres.logs['message'] = message.render()

        return message

    def _multiple_commits_roll_message(self, roll):
        """Build the Message for a roll of one project with several commits."""
        template = """
[{project_name}] Roll {num_commits} commits

{one_liners}

{remote}
{project_name} Rolled-Commits: {old_revision:.15}..{new_revision:.15}
        """.strip()

        one_liners = []
        for commit in roll.commits:
            # An empty commit message has no first line; use '' rather than
            # raising IndexError.
            lines = commit.message.splitlines()
            summary = lines[0] if lines else ''
            one_liners.append('{:.15} {}'.format(commit.hash, summary[:50]))

        # roll.commits is complete only when Roll._set_commits logged
        # old_revision..new_revision, i.e. for a FORWARD roll from a full
        # hash. Otherwise it holds just the newest few commits, so the real
        # count is unknown.
        if _is_hash(roll.old_revision) and roll.direction == _Direction.FORWARD:
            num_commits = len(roll.commits)
        else:
            num_commits = 'multiple'
            one_liners.append('...')

        kwargs = {
            'project_name': roll.project_name,
            'remote': roll.remote,
            'num_commits': num_commits,
            'one_liners': '\n'.join(one_liners),
            'old_revision': roll.old_revision,
            'new_revision': roll.new_revision,
        }

        message = Message(
            name=roll.project_name,
            template=template,
            kwargs=kwargs,
            num_commits=num_commits,
            footer=tuple(self.footer),
        )

        with self.m.step.nest('message') as pres:
            pres.logs['template'] = template
            pres.logs['kwargs'] = _pprint_dict(kwargs)
            pres.logs['message'] = message.render()

        return message

    def _single_roll_message(self, roll):
        """Build the Message for one roll, choosing the template by size."""
        if len(roll.commits) > 1:
            return self._multiple_commits_roll_message(roll)
        return self._single_commit_roll_message(roll)

    def _multiple_rolls_message(self, *rolls):
        """Build a combined message string for rolls of several projects."""
        rolls = sorted(rolls, key=lambda x: x.project_name)

        messages = [self._single_roll_message(roll) for roll in rolls]

        counts = [x.num_commits for x in messages]
        # A roll with an unknown count reports num_commits='multiple'; adding
        # that to ints would raise TypeError.
        if all(isinstance(count, int) for count in counts):
            total = sum(counts)
        else:
            total = 'multiple'

        texts = [
            '[roll {}] Roll {} commits'.format(
                ', '.join(x.name for x in messages), total,
            )
        ]
        texts.extend(x.render(with_footer=False) for x in messages)
        # Omit the footer section entirely when there is no footer so the
        # message does not end with blank lines.
        if self.footer:
            texts.append('\n'.join('{}'.format(x) for x in self.footer))

        return '\n\n'.join(texts)

    def Roll(self, **kwargs):
        """Create a Roll. See Roll class above for details."""
        return Roll(api=self.m, **kwargs)

    def message(self, *rolls):
        """Return the full commit message text for one or more rolls."""
        with self.m.step.nest('roll message'):
            if len(rolls) > 1:
                return self._multiple_rolls_message(*rolls)
            return self._single_roll_message(*rolls).render()

    Direction = _Direction

    def get_roll_direction(self, git_dir, old, new, name='get roll direction'):
        """Return the Direction of a roll from old to new.

        Runs `git merge-base --is-ancestor` both ways inside git_dir.

        Returns:
            Direction: CURRENT if old == new (or each is an ancestor of the
            other), FORWARD if only old is an ancestor of new, BACKWARD if
            only new is an ancestor of old, otherwise REBASE.
        """
        if old == new:
            with self.m.step.nest(name) as pres:
                pres.step_summary_text = 'up-to-date'
            return self.Direction.CURRENT

        with self.m.context(cwd=git_dir):
            with self.m.step.nest(name) as pres:
                forward = self.m.git(
                    'is forward',
                    'merge-base',
                    '--is-ancestor',
                    old,
                    new,
                    ok_ret=(0, 1),
                )

                backward = self.m.git(
                    'is backward',
                    'merge-base',
                    '--is-ancestor',
                    new,
                    old,
                    ok_ret=(0, 1),
                )

                old_is_ancestor = forward.exc_result.retcode == 0
                new_is_ancestor = backward.exc_result.retcode == 0

                if old_is_ancestor and not new_is_ancestor:
                    pres.step_summary_text = 'forward'
                    return self.Direction.FORWARD

                if new_is_ancestor and not old_is_ancestor:
                    pres.step_summary_text = 'backward'
                    return self.Direction.BACKWARD

                # If new and old are ancestors of each other then this is the
                # same commit. We should only hit this during testing because
                # the comparison at the top of the function should have caught
                # this situation. Annotate the current step rather than
                # opening a second, nested step with the same name.
                if old_is_ancestor and new_is_ancestor:
                    pres.step_summary_text = 'up-to-date'
                    return self.Direction.CURRENT

                # If old is not an ancestor of new and new is not an ancestor
                # of old then history was rewritten in some manner but we still
                # need to update the pin.
                pres.step_summary_text = 'rebase'
                return self.Direction.REBASE

    def can_roll(self, direction):
        """Return True if direction is one that should be rolled."""
        return direction in (self.Direction.FORWARD, self.Direction.REBASE)

    def skip_roll_step(self, remote, old_revision, new_revision):
        """Add a step explaining why a roll from old to new is not done.

        The step links both revisions under remote.
        """
        with self.m.step.nest('cancelling roll') as pres:
            fmt = (
                'not updating from {old} to {new} because {old} is newer '
                'than {new}'
            )
            if old_revision == new_revision:
                fmt = (
                    'not updating from {old} to {new} because they are '
                    'identical'
                )
            pres.step_summary_text = fmt.format(
                old=old_revision[0:7], new=new_revision[0:7]
            )
            pres.links[old_revision] = '{}/+/{}'.format(remote, old_revision)
            pres.links[new_revision] = '{}/+/{}'.format(remote, new_revision)

    def normalize_remote(self, remote, base):
        """Convert relative paths to absolute paths.

        Support relative paths. If the top-level project is
        "https://pigweed.googlesource.com/ex/ample" then a submodule path of
        "./abc" maps to "https://pigweed.googlesource.com/ex/ample/abc" and
        "../abc" maps to "https://pigweed.googlesource.com/ex/abc". Minimal
        error-checking because git does most of these checks for us.

        Also converts sso to https.

        Args:
            remote (str): Submodule remote URL.
            base (str): Fully-qualified superproject remote URL.
        """
        if remote.startswith('.'):
            # remote starts with '.', so it has no leading '/' to strip.
            remote = '/'.join((base.rstrip('/'), remote))

            # Collapse '/./' and '<segment>/../' until nothing changes; one
            # substitution can expose another.
            changes = 1
            while changes:
                changes = 0

                remote, n = re.subn(r'/\./', '/', remote)
                changes += n

                remote, n = re.subn(r'/[^/]+/\.\./', '/', remote)
                changes += n

        return self.m.sso.sso_to_https(remote)
