# Copyright 2022 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Classes for read and write transfers."""

import abc
import asyncio
from dataclasses import dataclass
import enum
import logging
import math
import threading
from typing import Any, Callable, Optional

from pw_status import Status
try:
    from pw_transfer import transfer_pb2
except ImportError:
    # For the bazel build, which puts generated protos in a different location.
    from pigweed.pw_transfer import transfer_pb2  # type: ignore

_LOG = logging.getLogger(__package__)


@dataclass(frozen=True)
class ProgressStats:
    """Snapshot of a transfer's progress, as passed to progress callbacks.

    Attributes:
      bytes_sent: Number of bytes sent so far.
      bytes_confirmed_received: Number of bytes confirmed as received.
      total_size_bytes: Total size of the transfer, or None if unknown.
    """
    bytes_sent: int
    bytes_confirmed_received: int
    total_size_bytes: Optional[int]

    def percent_received(self) -> float:
        """Returns the percentage of the total size confirmed as received.

        Returns NaN when the total size is unknown (None) or zero; in both
        cases the percentage is undefined. A zero total occurs when an empty
        transfer completes, which previously raised ZeroDivisionError.
        """
        if self.total_size_bytes is None or self.total_size_bytes == 0:
            return math.nan

        return self.bytes_confirmed_received / self.total_size_bytes * 100

    def __str__(self) -> str:
        # Only a missing size is "unknown"; a known size of 0 is shown as-is.
        if self.total_size_bytes is None:
            total = 'unknown'
        else:
            total = str(self.total_size_bytes)
        return (f'{self.percent_received():5.1f}% ({self.bytes_sent} B sent, '
                f'{self.bytes_confirmed_received} B received of {total} B)')


# Type of a user-supplied progress callback. It is invoked with a
# ProgressStats snapshot whenever a transfer reports progress; its return value
# is not used.
ProgressCallback = Callable[[ProgressStats], Any]


class _Timer:
    """A timer which invokes a callback after a certain timeout.

    The timer runs as an asyncio task, so start() must be called while an
    event loop is running.
    """
    def __init__(self, timeout_s: float, callback: Callable[[], Any]):
        """Creates a stopped timer.

        Args:
          timeout_s: Default timeout, in seconds, used by start().
          callback: Invoked with no arguments when the timer expires.
        """
        self.timeout_s = timeout_s
        self._callback = callback
        self._task: Optional[asyncio.Task[Any]] = None

    def start(self, timeout_s: Optional[float] = None) -> None:
        """Starts a new timer.

        If a timer is already running, it is stopped and a new timer started.
        This can be used to implement watchdog-like behavior, where a callback
        is invoked after some time without a kick.

        Args:
          timeout_s: Timeout for this run only; defaults to self.timeout_s.
        """
        self.stop()
        timeout_s = self.timeout_s if timeout_s is None else timeout_s
        self._task = asyncio.create_task(self._run(timeout_s))

    def stop(self) -> None:
        """Terminates a running timer. Does nothing if no timer is running."""
        if self._task is not None:
            self._task.cancel()
            self._task = None

    async def _run(self, timeout_s: float) -> None:
        await asyncio.sleep(timeout_s)
        # Clear the task before invoking the callback so that the callback can
        # restart the timer without cancelling the currently running task.
        self._task = None
        self._callback()


class Transfer(abc.ABC):
    """A client-side data transfer through a Manager.

    Subclasses are responsible for implementing all of the logic for their type
    of transfer, receiving messages from the server and sending the appropriate
    messages in response.
    """
    class _State(enum.Enum):
        # Transfer is starting. The server and client are performing an initial
        # handshake and negotiating protocol and feature flags.
        INITIATING = 0

        # Waiting for the other end to send a chunk.
        WAITING = 1

        # Transmitting a window of data to a receiver.
        TRANSMITTING = 2

        # Recovering after one or more chunks was dropped in an active transfer.
        RECOVERY = 3

        # Transfer has completed locally and is waiting for the peer to
        # acknowledge its final status. Only entered by the terminating side of
        # the transfer.
        #
        # The context remains in a TERMINATING state until it receives an
        # acknowledgement from the peer or times out.
        TERMINATING = 4

        # A transfer has fully completed.
        COMPLETE = 5

    def __init__(self,
                 session_id: int,
                 send_chunk: Callable[[transfer_pb2.Chunk], None],
                 end_transfer: Callable[['Transfer'], None],
                 response_timeout_s: float,
                 initial_response_timeout_s: float,
                 max_retries: int,
                 progress_callback: Optional[ProgressCallback] = None):
        """Initializes state shared by all transfer types.

        Args:
          session_id: ID of this transfer; used as the transfer_id of chunks
              sent by this class.
          send_chunk: Sends a chunk to the server.
          end_transfer: Called with this transfer when it finishes, unless
              finish() is called with skip_callback=True.
          response_timeout_s: Seconds to wait for a chunk from the server
              before retrying.
          initial_response_timeout_s: Timeout used for the response to the
              initial chunk sent by begin().
          max_retries: Number of consecutive timeouts allowed before the
              transfer fails with DEADLINE_EXCEEDED.
          progress_callback: Optional callable receiving ProgressStats.
        """
        self.id = session_id
        self.status = Status.OK
        # Set once the transfer has finished, successfully or not.
        self.done = threading.Event()

        self._send_chunk = send_chunk
        self._end_transfer = end_transfer

        # TODO(b/243679452): Only the legacy protocol begins in a WAITING state.
        # Newer protocols should start as INITIATING.
        self._state = Transfer._State.WAITING

        self._retries = 0
        self._max_retries = max_retries
        self._response_timer = _Timer(response_timeout_s, self._on_timeout)
        self._initial_response_timeout_s = initial_response_timeout_s

        self._progress_callback = progress_callback

    async def begin(self) -> None:
        """Sends the initial chunk of the transfer."""
        self._send_chunk(self._initial_chunk())
        self._response_timer.start(self._initial_response_timeout_s)

    @property
    @abc.abstractmethod
    def data(self) -> bytes:
        """Returns the data read or written in this transfer."""

    @abc.abstractmethod
    def _initial_chunk(self) -> transfer_pb2.Chunk:
        """Returns the initial chunk to notify the server of the transfer."""

    async def handle_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        Handles terminating chunks (i.e. those with a status) and forwards
        non-terminating chunks to _handle_data_chunk. Chunks that arrive after
        the transfer has completed are ignored.
        """
        if self._state is Transfer._State.COMPLETE:
            # A late chunk must not re-run finish() (and end_transfer), restart
            # the response timer, or reach subclass handlers that assume an
            # active transfer.
            _LOG.debug('Transfer %d: ignoring chunk received after completion',
                       self.id)
            return

        self._response_timer.stop()
        self._retries = 0  # Received data from service, so reset the retries.

        _LOG.debug('Received chunk\n%s', str(chunk).rstrip())

        # Status chunks are only used to terminate a transfer. They do not
        # contain any data that requires processing.
        if chunk.HasField('status'):
            self.finish(Status(chunk.status))
            return

        await self._handle_data_chunk(chunk)

        # The data handler may have finished the transfer (e.g. on the final
        # chunk or after sending an error). Only wait for another response from
        # the server while the transfer is still active.
        if self._state is not Transfer._State.COMPLETE:
            self._response_timer.start()

    @abc.abstractmethod
    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Handles a chunk that contains or requests data."""

    @abc.abstractmethod
    def _retry_after_timeout(self) -> None:
        """Retries after a timeout occurs."""

    def _on_timeout(self) -> None:
        """Handles a timeout while waiting for a chunk.

        Retries via _retry_after_timeout() until max_retries consecutive
        timeouts have occurred, then finishes with DEADLINE_EXCEEDED.
        """
        if self._state is Transfer._State.COMPLETE:
            return

        self._retries += 1
        if self._retries > self._max_retries:
            self.finish(Status.DEADLINE_EXCEEDED)
            return

        _LOG.debug('Received no responses for %.3fs; retrying %d/%d',
                   self._response_timer.timeout_s, self._retries,
                   self._max_retries)
        self._retry_after_timeout()
        self._response_timer.start()

    def finish(self, status: Status, skip_callback: bool = False) -> None:
        """Ends the transfer with the specified status.

        Stops the response timer, records the status, reports full progress
        on success, invokes end_transfer (unless skip_callback is True), and
        finally marks the transfer COMPLETE and sets `done`.
        """
        self._response_timer.stop()
        self.status = status

        if status.ok():
            total_size = len(self.data)
            self._update_progress(total_size, total_size, total_size)

        if not skip_callback:
            self._end_transfer(self)

        # Set done last so that the transfer has been fully cleaned up.
        self._state = Transfer._State.COMPLETE
        self.done.set()

    def _update_progress(self, bytes_sent: int, bytes_confirmed_received: int,
                         total_size_bytes: Optional[int]) -> None:
        """Invokes the provided progress callback, if any, with the progress."""

        stats = ProgressStats(bytes_sent, bytes_confirmed_received,
                              total_size_bytes)
        _LOG.debug('Transfer %d progress: %s', self.id, stats)

        if self._progress_callback:
            self._progress_callback(stats)

    def _send_error(self, error: Status) -> None:
        """Sends an error chunk to the server and finishes the transfer."""
        self._send_chunk(
            transfer_pb2.Chunk(transfer_id=self.id,
                               status=error.value,
                               type=transfer_pb2.Chunk.Type.COMPLETION))
        self.finish(error)


class WriteTransfer(Transfer):
    """A client -> server write transfer."""
    def __init__(
        self,
        session_id: int,
        data: bytes,
        send_chunk: Callable[[transfer_pb2.Chunk], None],
        end_transfer: Callable[[Transfer], None],
        response_timeout_s: float,
        initial_response_timeout_s: float,
        max_retries: int,
        progress_callback: Optional[ProgressCallback] = None,
    ):
        """Initializes a write transfer of `data`.

        Arguments other than `data` are forwarded to Transfer.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._data = data

        self._offset = 0
        self._window_end_offset = 0
        # Set from the server's transfer parameters; 0 until then.
        self._max_chunk_size = 0
        self._chunk_delay_us: Optional[int] = None

        # The window ID increments for each parameters update.
        self._window_id = 0

        self._bytes_confirmed_received = 0
        # Resent by _retry_after_timeout while WAITING.
        self._last_chunk = self._initial_chunk()

    @property
    def data(self) -> bytes:
        return self._data

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        # TODO(b/243679452): In protocol v2, session_id should be assigned by
        # the server during an initial handshake.
        return transfer_pb2.Chunk(transfer_id=self.id,
                                  resource_id=self.id,
                                  type=transfer_pb2.Chunk.Type.START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a write transfer, the server only sends transfer parameter updates
        to the client. When a message is received, update local parameters and
        send data accordingly.
        """

        if self._state is Transfer._State.TERMINATING:
            # TODO(b/243679452): Handle termination in protocol v2.
            return

        if self._state is Transfer._State.TRANSMITTING:
            self._state = Transfer._State.WAITING

        assert self._state is Transfer._State.WAITING

        if not self._handle_parameters_update(chunk):
            return

        self._bytes_confirmed_received = chunk.offset
        self._state = Transfer._State.TRANSMITTING

        # Bumping the window ID makes any transmission still scheduled for the
        # previous window stop (see _transmit_next_chunk).
        self._window_id += 1
        asyncio.create_task(self._transmit_next_chunk(self._window_id))

    async def _transmit_next_chunk(self,
                                   window_id: int,
                                   timeout_us: Optional[int] = None) -> None:
        """Transmits a single data chunk to the server.

        If the chunk completes the active window, returns to a WAITING state.
        Otherwise, schedules another transmission for the next chunk.

        Args:
          window_id: Window this transmission belongs to; stale windows are
              skipped.
          timeout_us: Optional delay, in microseconds, before transmitting.
        """
        if timeout_us is not None:
            await asyncio.sleep(timeout_us / 1e6)

        if self._state is not Transfer._State.TRANSMITTING:
            return

        if window_id != self._window_id:
            _LOG.debug('Transfer %d: Skipping stale window', self.id)
            return

        chunk = self._next_chunk()
        self._offset += len(chunk.data)

        # Use >= rather than ==: if the window end were ever behind the offset,
        # no chunk could make it equal and the transmit loop would never stop.
        sent_requested_bytes = self._offset >= self._window_end_offset

        self._send_chunk(chunk)
        self._update_progress(self._offset, self._bytes_confirmed_received,
                              len(self.data))

        if sent_requested_bytes:
            self._state = Transfer._State.WAITING
        else:
            asyncio.create_task(
                self._transmit_next_chunk(window_id,
                                          timeout_us=self._chunk_delay_us))

        self._last_chunk = chunk

    def _handle_parameters_update(self, chunk: transfer_pb2.Chunk) -> bool:
        """Updates transfer state based on a transfer parameters update.

        Returns:
          True if transmission should proceed. False if the parameters were
          invalid, in which case an error was sent and the transfer finished.
        """

        retransmit = True
        if chunk.HasField('type'):
            retransmit = (chunk.type
                          == transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT
                          or chunk.type == transfer_pb2.Chunk.Type.START)

        if chunk.offset > len(self.data):
            # Bad offset; terminate the transfer.
            _LOG.error(
                'Transfer %d: server requested invalid offset %d (size %d)',
                self.id, chunk.offset, len(self.data))

            self._send_error(Status.OUT_OF_RANGE)
            return False

        if chunk.pending_bytes == 0:
            _LOG.error(
                'Transfer %d: service requested 0 bytes (invalid); aborting',
                self.id)
            self._send_error(Status.INTERNAL)
            return False

        if retransmit:
            # Check whether the client has sent a previous data offset, which
            # indicates that some chunks were lost in transmission.
            if chunk.offset < self._offset:
                _LOG.debug('Write transfer %d rolling back: offset %d from %d',
                           self.id, chunk.offset, self._offset)

            self._offset = chunk.offset

            # Retransmit is the default behavior for older versions of the
            # transfer protocol. The window_end_offset field is not guaranteed
            # to be set in these version, so it must be calculated.
            max_bytes_to_send = min(chunk.pending_bytes,
                                    len(self.data) - self._offset)
            self._window_end_offset = self._offset + max_bytes_to_send
        else:
            assert chunk.type == transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE

            # Extend the window to the new end offset specified by the server.
            self._window_end_offset = min(chunk.window_end_offset,
                                          len(self.data))

        if chunk.HasField('max_chunk_size_bytes'):
            self._max_chunk_size = chunk.max_chunk_size_bytes

        if chunk.HasField('min_delay_microseconds'):
            self._chunk_delay_us = chunk.min_delay_microseconds

        if (self._max_chunk_size == 0
                and self._window_end_offset > self._offset):
            # With no chunk size, every data chunk would be empty and the
            # transmit loop would send empty chunks forever without progress.
            _LOG.error(
                'Transfer %d: service requested data with a max chunk size of '
                '0 (invalid); aborting', self.id)
            self._send_error(Status.INTERNAL)
            return False

        return True

    def _retry_after_timeout(self) -> None:
        """Resends the last chunk if waiting on the server."""
        if self._state is Transfer._State.WAITING:
            self._send_chunk(self._last_chunk)

        # TODO(b/243679452): Handle other states in protocol v2.

    def _next_chunk(self) -> transfer_pb2.Chunk:
        """Returns the next Chunk message to send in the data transfer."""
        chunk = transfer_pb2.Chunk(transfer_id=self.id,
                                   offset=self._offset,
                                   type=transfer_pb2.Chunk.Type.DATA)
        max_bytes_in_chunk = min(self._max_chunk_size,
                                 self._window_end_offset - self._offset)

        chunk.data = self.data[self._offset:self._offset + max_bytes_in_chunk]

        # Mark the final chunk of the transfer.
        if len(self.data) - self._offset <= max_bytes_in_chunk:
            chunk.remaining_bytes = 0

        return chunk


class ReadTransfer(Transfer):
    """A client <- server read transfer.

    Although Python can effectively handle an unlimited transfer window, this
    client sets a conservative window and chunk size to avoid overloading the
    device. These are configurable in the constructor.
    """

    # The fractional position within a window at which a receive transfer should
    # extend its window size to minimize the amount of time the transmitter
    # spends blocked.
    #
    # For example, a divisor of 2 will extend the window when half of the
    # requested data has been received, a divisor of three will extend at a
    # third of the window, and so on.
    EXTEND_WINDOW_DIVISOR = 2

    def __init__(  # pylint: disable=too-many-arguments
            self,
            session_id: int,
            send_chunk: Callable[[transfer_pb2.Chunk], None],
            end_transfer: Callable[[Transfer], None],
            response_timeout_s: float,
            initial_response_timeout_s: float,
            max_retries: int,
            max_bytes_to_receive: int = 8192,
            max_chunk_size: int = 1024,
            chunk_delay_us: Optional[int] = None,
            progress_callback: Optional[ProgressCallback] = None):
        """Initializes a read transfer.

        Args:
          max_bytes_to_receive: Size of each window requested in transfer
              parameters sent to the server.
          max_chunk_size: Value sent as max_chunk_size_bytes.
          chunk_delay_us: If set and nonzero, sent as min_delay_microseconds.
          Remaining arguments are forwarded to Transfer.
        """
        super().__init__(session_id, send_chunk, end_transfer,
                         response_timeout_s, initial_response_timeout_s,
                         max_retries, progress_callback)
        self._max_bytes_to_receive = max_bytes_to_receive
        self._max_chunk_size = max_chunk_size
        self._chunk_delay_us = chunk_delay_us

        self._remaining_transfer_size: Optional[int] = None
        self._data = bytearray()
        self._offset = 0
        # Invariant: _pending_bytes == _window_end_offset - _offset.
        self._pending_bytes = max_bytes_to_receive
        self._window_end_offset = max_bytes_to_receive
        self._last_chunk_offset: Optional[int] = None

    @property
    def data(self) -> bytes:
        """Returns an immutable copy of the data that has been read."""
        return bytes(self._data)

    def _initial_chunk(self) -> transfer_pb2.Chunk:
        return self._transfer_parameters(transfer_pb2.Chunk.Type.START)

    async def _handle_data_chunk(self, chunk: transfer_pb2.Chunk) -> None:
        """Processes an incoming chunk from the server.

        In a read transfer, the client receives data chunks from the server.
        Once all pending data is received, the transfer parameters are updated.
        """

        if self._state is Transfer._State.TERMINATING:
            # TODO(b/243679452): Handle terminating states in protocol v2.
            return

        if self._state is Transfer._State.RECOVERY:
            if chunk.offset != self._offset:
                if self._last_chunk_offset == chunk.offset:
                    _LOG.debug(
                        'Transfer %d received repeated offset %d: '
                        'retry detected, resending transfer parameters',
                        self.id, chunk.offset)
                    self._send_chunk(
                        self._transfer_parameters(
                            transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
                else:
                    _LOG.debug(
                        'Transfer %d waiting for offset %d, ignoring %d',
                        self.id, self._offset, chunk.offset)
                    self._last_chunk_offset = chunk.offset
                return

            _LOG.info(
                'Transfer %d received expected offset %d, resuming transfer',
                self.id, chunk.offset)
            self._state = Transfer._State.WAITING

        assert self._state is Transfer._State.WAITING

        if chunk.offset != self._offset:
            # Initially, the transfer service only supports in-order transfers.
            # If data is received out of order, request that the server
            # retransmit from the previous offset.
            _LOG.debug(
                'Transfer %d expected offset %d, received %d: '
                'entering recovery state', self.id, self._offset, chunk.offset)
            self._state = Transfer._State.RECOVERY

            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
            return

        self._data += chunk.data
        self._pending_bytes -= len(chunk.data)
        self._offset += len(chunk.data)

        # Update the last offset seen so that retries can be detected.
        self._last_chunk_offset = chunk.offset

        if chunk.HasField('remaining_bytes'):
            if chunk.remaining_bytes == 0:
                # No more data to read. Acknowledge receipt and finish.
                self._send_chunk(
                    transfer_pb2.Chunk(
                        transfer_id=self.id,
                        status=Status.OK.value,
                        type=transfer_pb2.Chunk.Type.COMPLETION))
                self.finish(Status.OK)
                return

            # The server may indicate if the amount of remaining data is known.
            self._remaining_transfer_size = chunk.remaining_bytes
        elif self._remaining_transfer_size is not None:
            # Update the remaining transfer size, if it is known.
            self._remaining_transfer_size -= len(chunk.data)

            # If the transfer size drops to zero, the estimate was inaccurate.
            if self._remaining_transfer_size <= 0:
                self._remaining_transfer_size = None

        total_size = None if self._remaining_transfer_size is None else (
            self._remaining_transfer_size + self._offset)
        self._update_progress(self._offset, self._offset, total_size)

        if chunk.window_end_offset != 0:
            if chunk.window_end_offset < self._offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid earlier end offset '
                    '%d (receiver offset %d)', self.id,
                    chunk.window_end_offset, self._offset)
                self._send_error(Status.INTERNAL)
                return

            if chunk.window_end_offset > self._window_end_offset:
                _LOG.error(
                    'Transfer %d: transmitter sent invalid later end offset '
                    '%d (receiver end offset %d)', self.id,
                    chunk.window_end_offset, self._window_end_offset)
                self._send_error(Status.INTERNAL)
                return

            # The transmitter may shrink the window. Pending bytes are whatever
            # remains between the current offset and the new window end; this
            # must be assigned, not subtracted from the already-decremented
            # pending count.
            self._window_end_offset = chunk.window_end_offset
            self._pending_bytes = self._window_end_offset - self._offset

        remaining_window_size = self._window_end_offset - self._offset
        extend_window = (remaining_window_size <= self._max_bytes_to_receive /
                         ReadTransfer.EXTEND_WINDOW_DIVISOR)

        if self._pending_bytes == 0:
            # All pending data was received. Send out a new parameters chunk for
            # the next block.
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))
        elif extend_window:
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_CONTINUE))

    def _retry_after_timeout(self) -> None:
        """Re-requests data from the current offset after a timeout."""
        if (self._state is Transfer._State.WAITING
                or self._state is Transfer._State.RECOVERY):
            self._send_chunk(
                self._transfer_parameters(
                    transfer_pb2.Chunk.Type.PARAMETERS_RETRANSMIT))

        # TODO(b/243679452): Handle other states in protocol v2.

    def _transfer_parameters(self, chunk_type: Any) -> transfer_pb2.Chunk:
        """Builds a transfer parameters chunk of the given type.

        Resets the receive window to start at the current offset and span
        max_bytes_to_receive bytes, then returns a chunk describing it. The
        chunk is returned, not sent; callers pass it to _send_chunk.
        """

        self._pending_bytes = self._max_bytes_to_receive
        self._window_end_offset = self._offset + self._max_bytes_to_receive

        chunk = transfer_pb2.Chunk(transfer_id=self.id,
                                   pending_bytes=self._pending_bytes,
                                   window_end_offset=self._window_end_offset,
                                   max_chunk_size_bytes=self._max_chunk_size,
                                   offset=self._offset,
                                   type=chunk_type)

        if self._chunk_delay_us:
            chunk.min_delay_microseconds = self._chunk_delay_us

        return chunk
