# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Gonk log stream handler."""

from datetime import datetime
import logging
import time
from typing import Optional

from google.protobuf.message import DecodeError
from pw_hdlc.decode import FrameDecoder
from pw_log.log_decoder import LogStreamDecoder
from pw_log.proto import log_pb2
from pw_tokenizer import detokenize
import pw_log.log_decoder

from gonk_adc.adc_measurement_pb2 import FramedProto
from gonk_tools.binary_handler import (
    BIN_LOG_SYNC_START_BYTES,
    BytearrayLoop,
)

_LOG = logging.getLogger('host')
_DEVICE_LOG = logging.getLogger('gonk')


def emit_python_log(log: pw_log.log_decoder.Log) -> None:
    """Forward a decoded device log entry to the ``gonk`` Python logger.

    The entry's own ``metadata_fields`` mapping is updated in place with the
    module name, source name, timestamp, message and file/line. That same
    mapping is then attached to the emitted record as
    ``extra_metadata_fields``.

    Args:
        log: A decoded pw_log entry.
    """
    fields = log.metadata_fields
    for key, value in (
        ('module', log.module_name),
        ('source_name', log.source_name),
        ('timestamp', log.timestamp),
        ('msg', log.message),
        ('file', log.file_and_line),
    ):
        fields[key] = value

    _DEVICE_LOG.log(
        log.level,
        '%s -- %s',
        log.message,
        log.file_and_line,
        extra={'extra_metadata_fields': fields},
    )


class GonkLogStream:
    """Handle incoming serial data from Gonk.

    The incoming byte stream interleaves HDLC-framed log data with binary
    ``FramedProto`` ADC packets. Each ADC packet starts with
    ``BIN_LOG_SYNC_START_BYTES``. Bytes before a packet are treated as log
    data and fed to an HDLC ``FrameDecoder``.
    """

    # pylint: disable=too-many-instance-attributes
    def __init__(
        self,
        detokenizer: Optional[detokenize.Detokenizer] = None,
        adc_time_format: str = '%Y%m%d %H:%M:%S.%f',
    ) -> None:
        """Create a stream handler.

        Args:
            detokenizer: Optional detokenizer. When provided, HDLC frames are
                parsed as ``log_pb2.LogEntry`` protos and decoded. Otherwise
                each frame is logged as-is.
            adc_time_format: ``strftime`` format for the host timestamp
                attached to each ADC measurement log line.
        """
        self.detokenizer = detokenizer
        self.log_decoder: Optional[LogStreamDecoder] = None
        if self.detokenizer:
            self.log_decoder = LogStreamDecoder(
                decoded_log_handler=emit_python_log, detokenizer=detokenizer)
        # Bytes not yet classified as log data or ADC proto packets.
        self.input_buffer = BytearrayLoop()
        # Bytes classified as log data, waiting to be HDLC-decoded.
        self.log_buffer = BytearrayLoop()
        self.frame_decoder = FrameDecoder()
        self.adc_time_format = adc_time_format

        # Total size (header + payload) of the ADC packet at the head of
        # input_buffer, or None while no packet header has been parsed.
        self.remaining_proto_bytes: Optional[int] = None

        # State for the periodic "ADC updates/second" log.
        self.proto_update_count: int = 0
        self.proto_update_time = time.monotonic()

    def _maybe_log_update_rate(self) -> None:
        """Count one ADC update; log and reset the count after >1s."""
        self.proto_update_count += 1
        current_time = time.monotonic()
        if current_time > self.proto_update_time + 1:
            _LOG.info('ADC updates/second: %d', self.proto_update_count)
            self.proto_update_count = 0
            self.proto_update_time = current_time

    def write(self, data: bytes) -> None:
        """Write data and separate protobuf data from logs.

        Appends ``data`` to the input buffer. It then repeatedly moves
        leading log bytes into the log buffer, decodes them, and extracts
        the next ADC packet, until no complete packet remains. Looping here,
        instead of handling a single packet per call, keeps the input buffer
        from building a backlog when one write delivers several packets.
        """
        self.input_buffer.write(data)

        while True:
            logdata = self.input_buffer.read_until(BIN_LOG_SYNC_START_BYTES)
            if logdata:
                self.log_buffer.write(logdata)

            self._handle_log_data()
            if not self._handle_proto_packets():
                break

    def _next_proto_packet_size(self) -> Optional[int]:
        """Return the total size of the ADC packet at the input buffer head.

        Returns None if fewer than 8 bytes are buffered or if the sync start
        bytes are not among them. Otherwise returns the header length (7 or
        8 bytes, depending on the varint length) plus the payload length
        encoded in the varint.
        """
        # Expected header bytes:
        # 1: fixed32 tag
        # 4: sync start value bytes
        # 1: varint tag
        # 2: varint value (one to two bytes)
        header_size = 8
        head_bytes = self.input_buffer.peek(header_size)
        # NOTE(review): this accepts the sync bytes anywhere in the first
        # 8 bytes, but the decoding below assumes they start at offset 0.
        if (len(head_bytes) < header_size
                or BIN_LOG_SYNC_START_BYTES not in head_bytes):
            return None

        # Bytes 0-4 hold the sync start field and byte 5 the varint tag;
        # only the 1-2 byte length varint at offsets 6-7 is decoded here.
        # Index the bytes directly: int.from_bytes() needs an explicit
        # byteorder argument before Python 3.11.
        first_byte = head_bytes[6]
        packet_size = first_byte & 0x7f
        if first_byte & 0x80:
            # Continuation bit set: the second byte supplies bits 7-13.
            packet_size |= (head_bytes[7] & 0x7f) << 7
        else:
            # One-byte varint: byte 7 is already part of the payload.
            header_size -= 1

        return header_size + packet_size

    def _handle_log_data(self) -> None:
        """HDLC-decode all buffered log bytes and emit the resulting logs."""
        log_data = self.log_buffer.read()
        for frame in self.frame_decoder.process(log_data):
            if not frame.ok():
                # The frame decoder persists across calls and has already
                # consumed these bytes. Drop them, as the message says.
                # Writing them back into log_buffer would re-feed them ahead
                # of newer log bytes, so the same bad frame would be
                # reported again on later writes.
                _LOG.warning(
                    'Failed to decode frame: %s; discarded %d bytes',
                    frame,
                    len(frame.raw_encoded),
                )
                continue

            if not self.log_decoder:
                # No detokenizer was supplied: log the raw frame.
                _DEVICE_LOG.info('%s', frame)
                continue

            log_entry = log_pb2.LogEntry()
            try:
                log_entry.ParseFromString(frame.data)
            except DecodeError:
                # Try to detokenize the frame data and log the failure.
                detokenized_text = ''
                if self.detokenizer:
                    detokenized_text = str(
                        self.detokenizer.detokenize(frame.data))
                if detokenized_text:
                    _LOG.warning('Failed to parse log proto "%s" %s',
                                 detokenized_text, frame)
                else:
                    _LOG.warning('Failed to parse log proto %s', frame)
                continue

            log = self.log_decoder.parse_log_entry_proto(log_entry)
            emit_python_log(log)

    def _handle_proto_packets(self) -> bool:
        """Extract and log one ADC packet if it is fully buffered.

        Returns:
            True if a complete packet was consumed from the input buffer,
            False if no packet header is available or the packet is still
            incomplete.
        """
        # If no packet header has been parsed yet, look for one.
        if self.remaining_proto_bytes is None:
            self.remaining_proto_bytes = self._next_proto_packet_size()

        packet_size = self.remaining_proto_bytes
        if packet_size is None or packet_size <= 0:
            return False

        # Wait until every byte of the packet has arrived.
        if len(self.input_buffer.peek(packet_size)) != packet_size:
            return False

        proto_bytes = bytes(self.input_buffer.read(packet_size))
        # Packet fully consumed; the next call parses a fresh header.
        self.remaining_proto_bytes = None
        self._parse_and_log_adc_proto(proto_bytes)
        self._maybe_log_update_rate()
        return True

    def _parse_and_log_adc_proto(self, proto_bytes: bytes) -> None:
        """Parse an ADC ``FramedProto`` message and log its measurements.

        If decoding fails, the packet is logged as hex at error level and
        dropped.
        """
        framed_proto = FramedProto()

        try:
            framed_proto.ParseFromString(proto_bytes)
        except DecodeError:
            _LOG.error('ADC FramedProto.DecodeError: %s', proto_bytes.hex())
            return

        host_time = datetime.now().strftime(self.adc_time_format)
        packet_size = len(proto_bytes)
        delta_micros = framed_proto.payload.timestamp

        measurements = framed_proto.payload.adc_measurements
        vbus_values = [measure.vbus_value for measure in measurements]
        vshunt_values = [measure.vshunt_value for measure in measurements]

        _DEVICE_LOG.info(
            'host_time: %s size: %s delta_microseconds: %s '
            'vbus: %s vshunt: %s',
            host_time,
            str(packet_size),
            str(delta_micros),
            # Argument order must match the 'vbus' / 'vshunt' labels above.
            ','.join(str(value) for value in vbus_values),
            ','.join(str(value) for value in vshunt_values),
            extra=dict(
                extra_metadata_fields={
                    'host_time': host_time,
                    'packet_size': packet_size,
                    'delta_micros': delta_micros,
                    'vbus_values': vbus_values,
                    'vshunt_values': vshunt_values,
                }),
        )
