# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Bazel-related functions."""

from __future__ import annotations

import dataclasses
import re
from typing import Any, Sequence, TypeVar, TYPE_CHECKING

from PB.recipe_modules.pigweed.bazel.options import Options
from recipe_engine import config_types, recipe_api

if TYPE_CHECKING:  # pragma: no cover
    from RECIPE_MODULES.pigweed.checkout import api as checkout_api


# Stand-in contents of the CIPD JSON file that BazelRunner._ensure_cipd()
# reads; passed as the `test_data` of that read_json step. Copied from
# bazel.json in Pigweed, but there's no need to keep it up-to-date:
# _ensure_cipd() only consumes each package's "path" and first "tags" entry.
_JSON_TEST_DATA: dict[str, Any] = {
    "included_files": [],
    "packages": [
        {
            "path": "fuchsia/third_party/bazel/${platform}",
            "platforms": ["linux-amd64", "mac-amd64", "windows-amd64"],
            "tags": ["version:2@6.3.0.6"],
            "version_file": ".versions/bazel.cipd_version",
        },
        {
            "path": "flutter/java/openjdk/${platform}",
            "platforms": [
                "linux-amd64",
                "mac-amd64",
                "mac-arm64",
                "windows-amd64",
            ],
            "tags": ["version:17"],
        },
    ],
}


@dataclasses.dataclass
class BazelRunner:
    """Installs Bazel and runs the Bazel commands described by ``options``.

    Attributes:
        api: Recipe API object; its ``path``, ``file``, ``cipd``, ``step``
            and ``context`` modules are used.
        checkout_root: Checkout root. Bazel commands run with this as their
            working directory, and ``options.cipd_json_path`` is resolved
            relative to it.
        options: Bazel module options (CIPD JSON path, Bazelisk version,
            top-level ``args`` and additional ``invocations``).
        _bazel: Cached path to the Bazel executable, set by ``ensure()``.
    """

    api: recipe_api.RecipeApi
    checkout_root: config_types.Path
    options: Options
    _bazel: config_types.Path | None = None

    def _ensure_cipd(self) -> config_types.Path:
        """Installs every package listed in the CIPD JSON file.

        Returns:
            Path to ``bazel`` inside the freshly created install root.
        """
        # read_json in a few lines requires absolute, normalized paths.
        json_path: config_types.Path = self.api.path.abspath(
            self.checkout_root / self.options.cipd_json_path
        )

        ensure_file = self.api.cipd.EnsureFile()
        for package in self.api.file.read_json(
            f'read {self.api.path.basename(json_path)}',
            json_path,
            test_data=_JSON_TEST_DATA,
        )["packages"]:
            # Only the first tag of each package is used as its version.
            ensure_file.add_package(package['path'], package['tags'][0])

        root = self.api.path.mkdtemp()
        self.api.cipd.ensure(root, ensure_file, name='ensure bazel')
        return root / 'bazel'

    def _ensure_bazelisk(self) -> config_types.Path:
        """Installs Bazelisk from CIPD.

        Uses ``options.bazelisk_version``, or ``'latest'`` when it is unset.

        Returns:
            Path to ``bazelisk`` inside the freshly created install root.
        """
        ensure_file = self.api.cipd.EnsureFile()
        ensure_file.add_package(
            'fuchsia/third_party/bazelisk/${platform}',
            self.options.bazelisk_version or 'latest',
        )

        root = self.api.path.mkdtemp()
        self.api.cipd.ensure(root, ensure_file, name='ensure bazelisk')
        return root / 'bazelisk'

    def ensure(self) -> config_types.Path:
        """Returns the Bazel executable, installing it on the first call.

        Installs from the CIPD JSON file when ``options.cipd_json_path`` is
        set, otherwise installs Bazelisk. After installing, runs a
        ``bazel version`` step. Later calls return the cached path.
        """
        if self._bazel is not None:
            return self._bazel

        if self.options.cipd_json_path:
            self._bazel = self._ensure_cipd()
        else:
            self._bazel = self._ensure_bazelisk()

        self.api.step('bazel version', [self._bazel, 'version'])
        return self._bazel

    def run(self, **kwargs) -> None:
        """Runs ``options.args`` (if non-empty), then each invocation.

        All commands run with ``checkout_root`` as the working directory.

        Args:
            **kwargs: Forwarded unchanged to every ``api.step`` call.
        """
        with self.api.context(cwd=self.checkout_root):
            if self.options.args:
                self._run_invocation(self.options.args, **kwargs)

            for invocation in self.options.invocations:
                assert invocation.args
                self._run_invocation(invocation.args, **kwargs)

    def _run_invocation(self, args: Sequence[str], **kwargs) -> None:
        """Runs one Bazel command in a step named ``bazel <args...>``."""
        name = ' '.join(['bazel', *args])
        self.api.step(name, [self.ensure(), *args], **kwargs)


def nwise(iterable, n):
    """Yields overlapping tuples of ``n`` consecutive items from ``iterable``.

    nwise('ABCDEFG', 3) → ABC BCD CDE DEF EFG

    Generalizes itertools.pairwise
    (https://docs.python.org/3/library/itertools.html#itertools.pairwise)
    to windows of any size. Nothing is yielded when ``iterable`` has fewer
    than ``n`` items.
    """
    source = iter(iterable)
    # Seed with a throwaway placeholder plus the first n - 1 items; the
    # placeholder is shifted out before the first window is yielded.
    window = (None,) + tuple(next(source, None) for _ in range(n - 1))
    for item in source:
        window = window[1:] + (item,)
        yield window


T = TypeVar('T')


def proximity_sort_nearby_lines(lines: Sequence[T]) -> list[T]:
    """Reorders ``lines`` center-out: the middle item first, then outward.

    The middle item is ``lines[len(lines) // 2]``. Items equally far from the
    middle keep their input order (earlier one first), so ``[a, b, c, d, e]``
    becomes ``[c, b, d, a, e]``. Only positions are compared, never the items
    themselves, so ``T`` does not need to be orderable.

    Args:
        lines: The items to reorder.

    Returns:
        A new list holding the same items, nearest-to-center first.
    """
    center = len(lines) // 2
    # sorted() is stable, so ties in distance preserve the input order.
    order = sorted(range(len(lines)), key=lambda idx: abs(center - idx))
    return [lines[idx] for idx in order]


@dataclasses.dataclass()
class UpdateCommitHashResult:
    """Outcome of ``BazelApi.update_commit_hash``.

    Attributes:
        old_revision: The commit hash that was replaced.
        project_name: Value of a ``name = "..."`` line found near the
            matching remote, or ``None`` when there was none.
    """

    old_revision: str
    project_name: str | None


class BazelApi(recipe_api.RecipeApi):
    """Bazel utilities."""

    BazelRunner = BazelRunner
    UpdateCommitHashResult = UpdateCommitHashResult

    def new_runner(
        self,
        checkout: checkout_api.CheckoutContext,
        options: Options | None,
    ) -> BazelRunner:
        """Creates a BazelRunner rooted at ``checkout.root``.

        Args:
            checkout: Checkout whose root becomes the runner's working
                directory.
            options: Bazel options. ``None`` is replaced by a default
                ``Options()`` message, because BazelRunner reads attributes
                of its options unconditionally.

        Returns:
            A new BazelRunner; Bazel is not installed until it is used.
        """
        if options is None:
            options = Options()
        return BazelRunner(self.m, checkout.root, options)

    def update_commit_hash(
        self,
        *,
        checkout: checkout_api.CheckoutContext,
        project_remote: str,
        new_revision: str,
        num_nearby_lines: int = 2,
        path: config_types.Path | None = None,
    ) -> UpdateCommitHashResult:
        """Rewrites the pinned commit of one repository in a Bazel file.

        Scans for ``remote = "..."`` lines until one is equivalent to
        ``project_remote`` (per ``checkout.remotes_equivalent``). Among the
        lines within ``num_nearby_lines`` of that line, it replaces the
        ``commit = "<40 hex chars>"`` line closest to the remote with
        ``new_revision``, and writes the file back.

        Args:
            checkout: Checkout used for the default path and for comparing
                remotes.
            project_remote: Remote URL of the repository to update.
            new_revision: Revision written into the ``commit`` line.
            num_nearby_lines: Lines searched before and after the ``remote``
                line for the ``commit`` and ``name`` lines.
            path: File to edit; defaults to ``WORKSPACE`` in the checkout
                root.

        Returns:
            The replaced revision and, if a ``name = "..."`` line was found
            nearby, the project name.
        """
        if not path:
            path = checkout.root / 'WORKSPACE'

        remote_rx = re.compile(r'^\s*remote\s*=\s*"(?P<remote>[^"]+)",?\s*$')
        commit_rx = re.compile(
            r'^(?P<prefix>\s*commit\s*=\s*")'
            r'(?P<commit>[0-9a-f]{40})'
            r'(?P<suffix>",?\s*)$'
        )
        name_rx = re.compile(r'^\s*name\s*=\s*"(?P<name>[^"]+)",?\s*$')

        # Pad both ends so every real line can sit at the center of a full
        # window of 2 * num_nearby_lines + 1 lines.
        padding = [''] * num_nearby_lines
        text = self.m.file.read_text(
            f'read old {path.name}',
            path,
            test_data=self.m.bazel.test_api.TEST_WORKSPACE_FILE,
        )
        lines = padding + text.strip().splitlines() + padding

        for nearby_lines in nwise(enumerate(lines), num_nearby_lines * 2 + 1):
            _, curr = nearby_lines[len(nearby_lines) // 2]
            match = remote_rx.search(curr)
            if not match:
                continue

            match_remote = match.group('remote')

            step = self.m.step.empty(f'found remote {match_remote!r}')
            if checkout.remotes_equivalent(match_remote, project_remote):
                step.presentation.step_summary_text = 'equivalent'
                break
            step.presentation.step_summary_text = 'not equivalent'

        else:
            # NOTE(review): the code below assumes step.empty raises for a
            # non-SUCCESS status (recipe_engine default); confirm.
            self.m.step.empty(
                f'could not find remote {project_remote} in {path}',
                status='FAILURE',
            )

        # Search center-out so the commit line nearest the remote wins.
        nearby_lines = proximity_sort_nearby_lines(nearby_lines)

        for i, line in nearby_lines:
            if match := commit_rx.search(line):
                idx = i
                break
        else:
            self.m.step.empty(
                f'could not find commit line adjacent to {curr!r} in {path}',
                status='FAILURE',
            )

        old_revision = match.group('commit')

        prefix = match.group("prefix")
        suffix = match.group("suffix")
        lines[idx] = f'{prefix}{new_revision}{suffix}'

        project_name: str | None = None
        for _, line in nearby_lines:
            if match := name_rx.search(line):
                project_name = match.group('name')
                break

        # Strip the padding with an explicit end index: ``lines[n:-n]`` is
        # empty when n == 0, which would truncate the whole file.
        body = lines[num_nearby_lines : len(lines) - num_nearby_lines]
        self.m.file.write_text(
            f'write new {path.name}',
            path,
            ''.join(f'{x}\n' for x in body),
        )

        return UpdateCommitHashResult(
            old_revision=old_revision,
            project_name=project_name,
        )
