# Copyright 2023 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Write a binary over serial to provision the Gonk FPGA."""

import argparse
import importlib.resources
from itertools import islice
import logging
import operator
from pathlib import Path
import sys
import time
from typing import Optional, Iterable

import serial
from serial import Serial
from serial.tools.list_ports import comports
from serial.tools.miniterm import Miniterm, Transform

from pw_cli import log as pw_cli_log
from pw_console import python_logging
import pw_cli.color
from pw_tokenizer import (
    database as pw_tokenizer_database,
    detokenize,
    tokens,
)

from gonk_tools.gonk_log_stream import GonkLogStream

_ROOT_LOG = logging.getLogger()
_LOG = logging.getLogger('host')
_DEVICE_LOG = logging.getLogger('gonk')

_COLOR = pw_cli.color.colors()

# Resource name of the bitstream bundled in the optional ``gonk_fpga`` package.
# _BUNDLED_FPGA_BINFILE stays empty unless that resource is actually present.
_BUNDLED_FPGA_BINFILE = ''
_BUNDLED_FPGA_BINFILE_NAME = 'toplevel.bin'

# Check for a bundled FPGA bitstream file. The resource is queried directly:
# importlib.resources.as_file() is only needed when a real filesystem path is
# required, and it may extract the resource to a temporary file to provide one.
try:
    if (importlib.resources.files('gonk_fpga') /
            _BUNDLED_FPGA_BINFILE_NAME).is_file():
        _BUNDLED_FPGA_BINFILE = _BUNDLED_FPGA_BINFILE_NAME
except ModuleNotFoundError:
    # The gonk_fpga package is optional; without it no default is available.
    pass


def _parse_args(argv: Optional[Iterable[str]] = None):
    """Parse command-line arguments.

    Args:
        argv: Arguments to parse. When None, argparse falls back to
            ``sys.argv[1:]``.

    Returns:
        The parsed ``argparse.Namespace``. Every attribute name matches a
        keyword parameter of ``main``.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        '--databases',
        metavar='elf_or_token_database',
        action=pw_tokenizer_database.LoadTokenDatabases,
        nargs='+',
        help=('ELF or token database files from which to read strings and '
              'tokens. For ELF files, the tokenization domain to read from '
              'may be specified after the path as #domain_name (e.g. '
              'foo.elf#TEST_DOMAIN). Unless specified, only the default '
              'domain ("") is read from ELF files; .* reads all domains. '
              'Globs are expanded to compatible database files.'),
    )
    parser.add_argument(
        '-p',
        '--port',
        type=str,
        help='Serial port path.',
    )
    parser.add_argument(
        '-b',
        '--baudrate',
        type=int,
        default=2000000,
        help='Sets the baudrate for serial communication.',
    )
    parser.add_argument(
        '--product',
        default='GENERIC_F730R8TX',
        help='Use first serial port matching this product name.',
    )
    parser.add_argument(
        '--serial-number',
        help='Use the first serial port matching this number.',
    )
    parser.add_argument(
        '--bitstream-file',
        type=Path,
        help=('FPGA Bitstream file. Can be a filesystem path or "DEFAULT" to '
              'use a Gonk Python tools bundled binary file.'),
    )
    parser.add_argument(
        '--logfile',
        type=Path,
        default=Path('gonk-logs.txt'),
        help=('Default log file. This will contain host side '
              'log messages only unless the '
              '--merge-device-and-host-logs argument is used.'),
    )
    parser.add_argument(
        '--merge-device-and-host-logs',
        action='store_true',
        help=('Include device logs in the default --logfile. '
              'These are normally shown in a separate device '
              'only log file.'),
    )
    parser.add_argument(
        '--log-colors',
        action=argparse.BooleanOptionalAction,
        help='Colors in log messages.',
    )
    parser.add_argument(
        '--log-to-stderr',
        action='store_true',
        help='Log to STDERR',
    )
    parser.add_argument(
        '--host-logfile',
        type=Path,
        help=('Additional host only log file. Normally all logs in the '
              'default logfile are host only.'),
    )
    parser.add_argument(
        '--device-logfile',
        type=Path,
        default=Path('gonk-device-logs.txt'),
        help='Device only log file.',
    )
    parser.add_argument(
        '--json-logfile',
        help='Device only JSON formatted log file.',
        type=Path,
    )

    return parser.parse_args(argv)


def get_serial_port(
    product: Optional[str] = None,
    serial_number: Optional[str] = None,
) -> str:
    """Find the device path of the first matching serial port.

    Ports are examined in order of their device path. A port whose serial
    number contains ``serial_number`` takes priority over any port whose
    product name contains ``product``.

    Args:
        product: Substring to look for in each port's product name.
        serial_number: Substring to look for in each port's serial number.

    Returns:
        The matching port's device path, or an empty string if none match.
    """
    by_device = sorted(comports(), key=operator.attrgetter('device'))

    # A serial-number match always wins over a product-name match.
    if serial_number is not None:
        for info in by_device:
            if (info.serial_number is not None
                    and serial_number in info.serial_number):
                return info.device

    if product is not None:
        for info in by_device:
            if info.product is not None and product in info.product:
                return info.device

    # Nothing matched.
    return ''


# Exact size in bytes expected of an FPGA bitstream file.
FILE_LENGTH = 135100
# A valid bitstream begins with FILE_START followed by SYNC_START. Both are
# kept as space-separated hex strings so they can be quoted in error messages.
FILE_START_STR = 'FF 00 00 FF'
SYNC_START_STR = '7E AA 99 7E'
FILE_START_BYTES = bytes.fromhex(FILE_START_STR)
SYNC_START_BYTES = bytes.fromhex(SYNC_START_STR)


class UnknownSerialDevice(Exception):
    """Raised when no serial port was given and none could be auto-detected."""


class IncorrectBinaryFormat(Exception):
    """Raised when a bitstream file has the wrong length or header bytes."""


def batched(iterable, n):
    """Batch data into tuples of length n. The last batch may be shorter.

    Equivalent to the ``itertools.batched`` recipe (built in from Python 3.12).

    Example usage:

    .. code-block:: pycon

       >>> list(batched('ABCDEFG', 3))
       [('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]

    Args:
        iterable: Any iterable to split into batches.
        n: Maximum batch size; must be at least 1.

    Yields:
        Tuples of up to ``n`` consecutive items from ``iterable``.

    Raises:
        ValueError: If ``n`` is less than 1. Because this is a generator, the
            error surfaces when iteration starts, not at call time.
    """
    if n < 1:
        raise ValueError('n must be at least one')
    it = iter(iterable)
    # islice() pulls at most n items; an empty tuple means the input is done.
    while batch := tuple(islice(it, n)):
        yield batch


class HandleBinaryData(Transform):
    """Miniterm rx transform that routes raw serial bytes to a GonkLogStream.

    ``MinitermBinary`` passes this transform the undecoded bytes as ``data``
    in addition to the decoded text. Whenever bytes are supplied they are
    written to the log stream and nothing is echoed to the console.
    """
    def __init__(self, gonk_log_stream: GonkLogStream) -> None:
        self.gonk_log_stream = gonk_log_stream

    def rx(self, text: str, data: Optional[bytes] = None) -> str:
        """Consume data received from the serial port.

        Returns ``text`` untouched when ``data`` is empty or None; otherwise
        forwards ``data`` to the log stream and returns an empty string.
        """
        if not data:
            return text
        self.gonk_log_stream.write(data)
        return ''


class DebugSerialIO(Transform):
    """Miniterm transform that traces serial traffic to stderr.

    Received and sent text is written to stderr inside ``[Recv: ...]`` and
    ``[Send: ...]`` markers. All text passes through unchanged.
    """
    @staticmethod
    def _trace(message: str) -> None:
        # Flush immediately so traces interleave correctly with other output.
        sys.stderr.write(message)
        sys.stderr.flush()

    def rx(self, text: str) -> str:
        """Trace text received from the serial port and pass it through."""
        self._trace(f'[Recv: {text!r}] ')
        return text

    def tx(self, text: str) -> str:
        """Trace text bound for the serial port and pass it through."""
        self._trace(f' [Send: {text!r}] ')
        return text

    def echo(self, text: str) -> str:
        """Pass locally echoed text through unchanged."""
        return text


class MinitermBinary(Miniterm):
    """Miniterm whose reader also hands raw bytes to HandleBinaryData."""
    def reader(self):
        """Copy data from the serial port to the console until stopped."""
        try:
            while self.alive and self._reader_alive:
                # Take everything already buffered, or block for one byte.
                chunk = self.serial.read(self.serial.in_waiting or 1)
                if not chunk:
                    continue
                if self.raw:
                    self.console.write_bytes(chunk)
                    continue
                decoded = self.rx_decoder.decode(chunk)
                for transform in self.rx_transformations:
                    # Only HandleBinaryData accepts the undecoded bytes too.
                    if isinstance(transform, HandleBinaryData):
                        decoded = transform.rx(decoded, chunk)
                    else:
                        decoded = transform.rx(decoded)
                self.console.write(decoded)
        except serial.SerialException:
            self.alive = False
            self.console.cancel()
            raise  # XXX handle instead of re-raise?


def load_bitstream_file(bitstream_file: Path) -> bytes:
    """Load an FPGA bitstream file and check that it has the expected format.

    Args:
        bitstream_file: Path to a bitstream file, or the special value
            ``DEFAULT`` to load the bitstream bundled in the ``gonk_fpga``
            package. A plain ``str`` path is also accepted.

    Returns:
        The raw contents of the bitstream file.

    Raises:
        FileNotFoundError: If ``DEFAULT`` was requested but no bundled
            bitstream is available, or if ``bitstream_file`` is not a file.
        IncorrectBinaryFormat: If the contents are not exactly FILE_LENGTH
            bytes long, or do not start with FILE_START_BYTES followed by
            SYNC_START_BYTES.
    """
    if str(bitstream_file) == 'DEFAULT':
        if not _BUNDLED_FPGA_BINFILE:
            raise FileNotFoundError('No default bitstream file is available.')

        bitstream_path = importlib.resources.files('gonk_fpga').joinpath(
            _BUNDLED_FPGA_BINFILE)
        bitstream_bytes = bitstream_path.read_bytes()
    else:
        # Accept str paths as well as Path objects.
        bitstream_file = Path(bitstream_file)
        if not bitstream_file.is_file():
            raise FileNotFoundError(f'\nUnable to load "{bitstream_file}"')
        bitstream_bytes = bitstream_file.read_bytes()

    # Use the same constants that the error message reports, so the check
    # and the message cannot drift apart.
    expected_header = FILE_START_BYTES + SYNC_START_BYTES
    if (len(bitstream_bytes) != FILE_LENGTH
            or not bitstream_bytes.startswith(expected_header)):
        raise IncorrectBinaryFormat(
            f'the bitstream file must be:\n'
            f'  {FILE_LENGTH} bytes in length\n'
            f'  Start with "{FILE_START_STR} {SYNC_START_STR}"')

    return bitstream_bytes


def write_bitstream_file(bitstream_bytes: bytes,
                         serial_instance: Serial) -> None:
    """Send ``bitstream_bytes`` over ``serial_instance`` in 8-byte chunks.

    ``flush()`` is called after every chunk. Start and completion, including
    the total count reported by ``write()``, are logged to the host logger.
    """
    _LOG.info('Sending bitstream...')
    total_written: int = 0
    for chunk in batched(bitstream_bytes, 8):
        count = serial_instance.write(chunk)
        # Only count truthy results; write() may return None or 0.
        total_written += count or 0
        serial_instance.flush()
    _LOG.info('Done sending bitstream. Wrote %d', total_written)


def _configure_logging(
    # pylint: disable=too-many-arguments
    *,
    logfile,
    host_logfile: Optional[str],
    device_logfile: Optional[str],
    json_logfile: Optional[str],
    log_colors: bool,
    merge_device_and_host_logs: bool,
    verbose: bool,
    log_to_stderr: bool,
) -> None:
    """Install the host, device, JSON and stderr log handlers used by main.

    Host logs (the root logger) go to ``logfile``. Device logs (the ``gonk``
    logger) are kept off the root logger. They are written to
    ``device_logfile`` and ``json_logfile`` when given, and to ``logfile``
    only when ``merge_device_and_host_logs`` is set.
    """
    log_level = logging.DEBUG if verbose else logging.INFO
    colors = pw_cli.color.colors(log_colors)
    logger_name_format = colors.cyan('%(name)s')
    # Options shared by every pw_cli_log.install() call below.
    install_options = {
        'level': log_level,
        'use_color': log_colors,
        'hide_timestamp': False,
        'message_format': f'[{logger_name_format}] %(levelname)s %(message)s',
    }

    pw_cli_log.install(log_file=logfile, **install_options)

    if device_logfile:
        pw_cli_log.install(log_file=device_logfile,
                           logger=_DEVICE_LOG,
                           **install_options)
    if host_logfile:
        pw_cli_log.install(log_file=host_logfile,
                           logger=_ROOT_LOG,
                           **install_options)

    # By default don't send device logs to the root logger.
    _DEVICE_LOG.propagate = False
    if merge_device_and_host_logs:
        # Add device logs to the default logfile.
        pw_cli_log.install(log_file=logfile,
                           logger=_DEVICE_LOG,
                           **install_options)

    if json_logfile:
        json_filehandler = logging.FileHandler(json_logfile, encoding='utf-8')
        json_filehandler.setLevel(log_level)
        json_filehandler.setFormatter(python_logging.JsonLogFormatter())
        _DEVICE_LOG.addHandler(json_filehandler)

    if log_to_stderr:
        pw_cli_log.install(**install_options)

    _LOG.setLevel(log_level)
    _DEVICE_LOG.setLevel(log_level)
    _ROOT_LOG.setLevel(log_level)


def main(
    # pylint: disable=too-many-arguments
    baudrate: int,
    databases: Optional[Iterable],
    bitstream_file: Optional[Path],
    port: Optional[str] = None,
    product: Optional[str] = None,
    serial_number: Optional[str] = None,
    logfile: Optional[str] = None,
    host_logfile: Optional[str] = None,
    device_logfile: Optional[str] = None,
    json_logfile: Optional[str] = None,
    log_colors: Optional[bool] = False,
    merge_device_and_host_logs: bool = False,
    verbose: bool = False,
    log_to_stderr: bool = False,
) -> int:
    """Write a bitstream file over serial while monitoring output.

    Args:
        baudrate: Serial baud rate.
        databases: Token databases used to detokenize device output. None
            (``--databases`` omitted) is treated as no databases.
        bitstream_file: Bitstream to send, ``DEFAULT`` for the bundled one,
            or None to only monitor the serial port.
        port: Serial port path. When None, it is auto-detected using
            ``serial_number`` first and then ``product``.
        product: Product-name substring used for port auto-detection.
        serial_number: Serial-number substring used for port auto-detection.
        logfile: Main log file. A temporary file is used when empty.
        host_logfile: Optional extra file for host logs.
        device_logfile: Optional file for device logs.
        json_logfile: Optional JSON-formatted file for device logs.
        log_colors: Whether to colorize logs. None means True.
        merge_device_and_host_logs: Also write device logs to ``logfile``.
        verbose: Log at DEBUG instead of INFO.
        log_to_stderr: Also log host messages to stderr.

    Returns:
        0 after the serial monitor exits.

    Raises:
        UnknownSerialDevice: If no serial port was given or detected.
        FileNotFoundError: If ``bitstream_file`` cannot be found.
        IncorrectBinaryFormat: If ``bitstream_file`` fails validation.
    """
    if not logfile:
        # Create a temp logfile to prevent logs from appearing over stdout.
        # This would corrupt the prompt toolkit UI.
        logfile = python_logging.create_temp_log_file()

    if log_colors is None:
        log_colors = True

    _configure_logging(
        logfile=logfile,
        host_logfile=host_logfile,
        device_logfile=device_logfile,
        json_logfile=json_logfile,
        log_colors=log_colors,
        merge_device_and_host_logs=merge_device_and_host_logs,
        verbose=verbose,
        log_to_stderr=log_to_stderr,
    )

    # Init serial port.
    if port is None:
        port = get_serial_port(product=product, serial_number=serial_number)

    if not port:
        raise UnknownSerialDevice(
            'No --serial-number --product or --port path provided.')

    # Load and validate the bitstream before opening the port, so a missing
    # or malformed file is reported before the serial monitor starts.
    bitstream_bytes = (load_bitstream_file(bitstream_file)
                       if bitstream_file else None)

    serial_instance = Serial(port=port, baudrate=baudrate, timeout=0.1)

    # --databases has no argparse default, so it arrives as None when omitted.
    # Treat that as "no databases" instead of crashing on `*None`.
    detokenizer = detokenize.Detokenizer(
        tokens.Database.merged(*(databases or ())),
        show_errors=True,
    )

    # Use pyserial miniterm to monitor received data.
    miniterm = MinitermBinary(
        serial_instance,
        echo=True,
        eol='lf',
        filters=(),
    )

    # Use Ctrl-C as the exit character. (Miniterm default is Ctrl-])
    miniterm.exit_character = chr(0x03)
    miniterm.set_rx_encoding('utf-8', errors='backslashreplace')
    miniterm.set_tx_encoding('utf-8')

    gonk_log_stream = GonkLogStream(detokenizer)
    miniterm.rx_transformations.append(HandleBinaryData(gonk_log_stream))
    miniterm.tx_transformations.append(DebugSerialIO())

    # Start monitoring serial data.
    miniterm.start()

    # Send the bitstream file.
    if bitstream_bytes is not None:
        # Wait a couple seconds to print early log messages from Gonk.
        time.sleep(2)
        write_bitstream_file(bitstream_bytes, serial_instance)

    # Wait for ctrl-c, then shutdown miniterm.
    try:
        miniterm.join(True)
    except KeyboardInterrupt:
        pass
    sys.stderr.write('\n--- exit ---\n')
    miniterm.join()
    miniterm.close()

    return 0


if __name__ == '__main__':
    _CLI_ARGS = _parse_args()
    # Every parsed option is passed to main() as the keyword of the same name.
    sys.exit(main(**vars(_CLI_ARGS)))
