#!/usr/bin/env python3
# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.

# TODO(mohrr) Write unit tests for this.

import argparse
import configparser
import functools
import json
import pathlib
import re
import subprocess
import sys


def debug(*args, **kwargs):
    """Print a diagnostic message to stderr.

    Takes the same arguments as print(). Any ``file`` keyword supplied by
    the caller is replaced, so output always goes to sys.stderr.
    """
    enabled = True  # Hard-wired switch; diagnostics are always emitted.
    if not enabled:
        return None
    kwargs.update(file=sys.stderr)
    return print(*args, **kwargs)


def trace(func):
    """Decorator that logs entry to and exit from ``func`` via debug().

    The wrapper logs a line before calling ``func``, then logs either the
    returned value or the raised exception. Return values are passed
    through unchanged and exceptions are re-raised after being logged.

    functools.wraps copies ``func``'s metadata (__name__, __doc__, ...)
    onto the wrapper so decorated functions remain identifiable in
    tracebacks and introspection instead of all being named 'wrapped'.
    """

    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        try:
            debug(f'== entering {func.__name__}')
            res = func(*args, **kwargs)
            debug(f'== exiting {func.__name__} with result {res}')
            return res
        except Exception as e:
            # Log and propagate; tracing must never swallow errors.
            debug(f'== exiting {func.__name__} with exception {e}')
            raise

    return wrapped


@trace
def _git(path, *cmd):
    """Run ``git <cmd...>`` with ``path`` as cwd and return stdout bytes.

    The exit status is not inspected; stderr is captured but not returned.
    """
    command = ['git', *cmd]
    debug('running', command)
    completed = subprocess.run(args=command, cwd=path, capture_output=True)
    output = completed.stdout
    debug('result:', output)
    return output


@trace
def _parse_config_file(path):
    """Parse the file at ``path`` and return a configparser.ConfigParser.

    ``path`` must provide an ``open()`` method (e.g. a pathlib path).
    """
    # Leading whitespace is dropped from every line before parsing so that
    # indented key/value pairs are never read as continuation lines, even
    # when the indentation mixes tabs and spaces.
    with path.open('r') as ins:
        stripped_lines = [line.lstrip() for line in ins]

    parser = configparser.ConfigParser()
    parser.read_string(''.join(stripped_lines))
    return parser


@trace
def _parse_gitmodules_file(git_root, data, to_process, prefix=None):
    """Record the submodules declared in ``git_root/.gitmodules``.

    Does nothing when the file does not exist.

    Args:
        git_root: Directory (path object) whose .gitmodules file is read.
        data: Dict keyed by submodule path. An entry is created or updated
            in place for every submodule section found.
        to_process: List that receives the resolved directory of every
            submodule found, for the caller to visit later.
        prefix: Path of ``git_root`` relative to the top-level checkout.
            None, '' and '.' all mean "no prefix".

    Raises:
        ValueError: If the DEFAULT section is non-empty or a section name
            is not of the form ``submodule "<name>"``.
        KeyError: If a submodule section has no ``path`` or ``url``.
    """
    debug('git_root', git_root)
    debug('data.keys()', data.keys())
    debug('len(to_process)', len(to_process))
    debug('prefix', prefix)

    gitmodules = git_root / '.gitmodules'
    if not gitmodules.is_file():
        return

    config = _parse_config_file(gitmodules)

    # Nested submodules are keyed by their path from the top-level root.
    prefix = f'{prefix}/' if prefix and prefix != '.' else ''

    optional_keys = (
        'branch',
        'update',
        'ignore',
        'shallow',
        'fetchRecurseSubmodules',
    )

    for section_name, section in config.items():
        if section_name == 'DEFAULT':
            # configparser always yields DEFAULT; it must carry no options.
            if not section:
                continue
            raise ValueError(
                f'non-empty DEFAULT section ({list(section.items())})'
            )

        header = re.search(r'^submodule "(.*)"$', section_name)
        if header is None:
            raise ValueError(section_name)

        raw_path = section['path']
        path = f'{prefix}{raw_path}'.replace('\\', '/')

        entry = data.setdefault(path, {})
        entry['path'] = path
        entry['name'] = f'{prefix}{header.group(1)}'
        entry['url'] = section['url']
        # Optional settings are stored as None when absent.
        for key in optional_keys:
            entry[key] = section.get(key)

        resolved_path = (git_root / raw_path).resolve()
        debug('appending', resolved_path)
        to_process.append(resolved_path)


@trace
def _parse_gitmodules(git_root, data, recursive):
    """Populate ``data`` from the .gitmodules files under ``git_root``.

    Args:
        git_root: Path object for the top-level checkout.
        data: Dict filled in place, keyed by submodule path relative to
            ``git_root``.
        recursive: When true, also read the .gitmodules file of every
            submodule directory discovered; otherwise only the top-level
            file is read.
    """
    # Submodule directories are queued as resolved paths, so the root has
    # to be resolved as well. Otherwise relative_to() below raises
    # ValueError whenever git_root is relative or reached via a symlink.
    git_root = git_root.resolve()

    seen_submodules = set()
    to_process = [git_root]
    while to_process:
        curr = to_process.pop()
        if curr in seen_submodules:
            # Already handled; skipping avoids redundant work and cycles.
            continue
        debug('processing', curr)
        seen_submodules.add(curr)
        _parse_gitmodules_file(
            curr, data, to_process, prefix=str(curr.relative_to(git_root))
        )
        if not recursive:
            # Only the top-level .gitmodules is wanted.
            break


@trace
def _add_remotes(git_root, data):
    """Add a 'remote' key to every submodule entry in ``data``.

    A URL starting with '.' is treated as relative: it is appended to the
    ``remote.origin.url`` of ``git_root`` and then '/./' and '/<dir>/../'
    segments are collapsed. Any other URL is copied unchanged.
    """
    origin = _git(git_root, 'config', '--get', 'remote.origin.url')
    base = origin.decode().strip()

    for submodule in data.values():
        remote = submodule['url']
        if remote.startswith('.'):
            remote = base.rstrip('/') + '/' + remote.lstrip('/')

            # Collapsing one segment can expose another, so keep going
            # until a complete pass rewrites nothing.
            while True:
                remote, dot_hits = re.subn(r'/\./', '/', remote)
                remote, dotdot_hits = re.subn(r'/[^/]+/\.\./', '/', remote)
                if not (dot_hits or dotdot_hits):
                    break

        submodule['remote'] = remote


@trace
def _parse_submodule_status(line, data):
    """Record one line of `git submodule status` output in ``data``.

    The line is expected to look like
    ``[+|-|U]<40-hex-hash> <path> [(<describe>)]``. The entry
    ``data[<path>]`` gains the keys 'initialized', 'modified', 'conflict',
    'hash' and 'describe' ('describe' is None when the parenthesised part
    is absent).

    Raises:
        ValueError: If the line does not match the expected format.
        KeyError: If the path has no entry in ``data``.
    """
    match = re.search(
        # '-' comes first in the class so it is a literal. Written as
        # '[+-U]' it forms a range from '+' to 'U', which also accepts
        # digits, 'A'-'T' and punctuation as a status character.
        r'^(?P<status>[-+U]?)(?P<hash>[0-9a-fA-F]{40})\s+'
        r'(?P<path>[^()]*?)\s*'
        r'(?:\((?P<describe>[^()]*)\))?$',
        line.strip(),
    )
    if not match:
        raise ValueError(f'unrecognized submodule status line "{line}"')

    try:
        submodule = data[match.group('path')]
    except KeyError:
        # Show the known paths to make the mismatch easy to diagnose.
        debug(data.keys())
        raise

    status = match.group('status')
    submodule['initialized'] = status != '-'
    submodule['modified'] = status == '+'
    submodule['conflict'] = status == 'U'
    submodule['hash'] = match.group('hash')
    submodule['describe'] = match.group('describe')


@trace
def _add_status(git_root, data, recursive):
    """Run `git submodule status` in ``git_root`` and record each line.

    ``--recursive`` is passed to git when ``recursive`` is true. Every
    output line is handed to _parse_submodule_status() together with
    ``data``.
    """
    cmd = ['git', 'submodule', 'status'] + (
        ['--recursive'] if recursive else []
    )
    debug('running', cmd)
    proc = subprocess.run(cmd, cwd=git_root, capture_output=True)
    status_text = proc.stdout.decode()
    debug('result:', status_text)
    for status_line in status_text.splitlines():
        _parse_submodule_status(status_line, data)


@trace
def main(git_root, output_file, recursive):
    """Collect submodule information for ``git_root`` and write it as JSON.

    The result is a dict keyed by submodule path, built from .gitmodules,
    the origin remote and `git submodule status`, and serialized to the
    writable file object ``output_file``.
    """
    submodules = {}
    _parse_gitmodules(git_root, submodules, recursive)
    _add_remotes(git_root, submodules)
    _add_status(git_root, submodules, recursive)

    json.dump(submodules, fp=output_file)


@trace
def parse(argv=None):
    """Parse command-line arguments and return the argparse namespace.

    ``argv`` is passed to parse_args(); None makes argparse read
    sys.argv. The namespace has ``git_root`` (pathlib.Path),
    ``output_file`` (file opened for writing) and ``recursive`` (bool).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument('git_root', type=pathlib.Path)
    arg_parser.add_argument('output_file', type=argparse.FileType('w'))
    arg_parser.add_argument('--recursive', action='store_true')
    namespace = arg_parser.parse_args(argv)
    return namespace


if __name__ == '__main__':
    # The parsed namespace attributes map one-to-one onto main()'s
    # parameters, so they are forwarded as keyword arguments.
    cli_options = vars(parse())
    main(**cli_options)
    sys.exit(0)
