# Copyright 2024 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests for gonk binary data"""

import unittest

from gonk_tools.binary_handler import (
    BIN_LOG_SYNC_START_BYTES,
    BIN_LOG_SYNC_START_STR,
    BytearrayLoop,
    decode_one_varint,
)
from gonk_tools.gonk_log_stream import GonkLogStream

# pylint: disable=line-too-long
# Sample non-proto "log data". Each sample starts and ends with a b'~' byte.
# The bytes after '$' look like base64 text followed by a few binary bytes.
# NOTE(review): presumably these are framed, base64-tokenized log messages;
# confirm the exact framing/encoding against the log encoder.
_SAMPLE_ENCODED_LOGS = [
    b'~$llsrqA==E\xc4\xe2\x93~',
    b'~$w4ZyKAI=0\x8f\xbf!~',
    b'~$w4ZyKAQ=i\xa4\xa3~',
    b'~$w4ZyKAY=a\x9d}]k~',
    b'~$w4ZyKAg=\x86\xb8*~',
    b'~$YNjZuQ==X\xb2\xc1@~',
    b'~$J+09dYBA\xb2\xf8^~',
]

# Sample binary proto packets. Each packet is laid out as:
#   BIN_LOG_SYNC_START (5 bytes)
#   | 0x12 tag byte (protobuf field 2, length-delimited)
#   | varint holding the number of bytes that follow
#   | payload
# Every length here is below 0x80, so the length varint is a single byte.
_SAMPLE_PROTO_PACKETS = [
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}122310d380ef020a04080010030a04080010020a040800104e0a040845104f0a0408001002'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}122110bf010a04080010030a04080010020a040800104e0a04080010020a040800104f'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}122a10b9010a04080010030a04080010020a04080010030a04080010020a0d080010b6ffffffffffffffff01'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}122a10d1010a04080010030a04080010020a04084710030a0d080010b7ffffffffffffffff010a0408001002'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}123310df010a04080010030a0d080010b6ffffffffffffffff010a04080010030a0d080010b7ffffffffffffffff010a0408001002'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}127210ffffffff0f0a04080010030a04080010020a04080010030a04080010020a04080010020a0d080010b7ffffffffffffffff010a0d080010b7ffffffffffffffff010a04084710020a04084710030a04080010030a04080010020a1608ffffffffffffffffff0110ffffffffffffffffff01'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}127b10ffffffff0f0a04080010030a04080010020a04080010030a04080010020a0d080010b6ffffffffffffffff010a0d080010b7ffffffffffffffff010a0d080010b7ffffffffffffffff010a04080010020a04080010030a04080010030a04080010020a1608ffffffffffffffffff0110ffffffffffffffffff01'
    ),
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}127210ffffffff0f0a04080010030a04080010020a04080010030a04084510020a04080010020a04080010020a0d080010b7ffffffffffffffff010a04084610020a04080010030a04080010030a0d084810b6ffffffffffffffff010a1608ffffffffffffffffff0110ffffffffffffffffff01'
    ),
]
_SAMPLE_LONG_PROTO_PACKETS = [
    # Artificially increased size > 127 (uses two bytes for the packet size varint)
    # The size varint 0xF2 0x01 decodes to 242 remaining bytes. The payload
    # is the last _SAMPLE_PROTO_PACKETS payload repeated twice, followed by
    # 14 extra bytes of padding.
    bytes.fromhex(
        f'{BIN_LOG_SYNC_START_STR}12F20110ffffffff0f0a04080010030a04080010020a04080010030a04084510020a04080010020a04080010020a0d080010b7ffffffffffffffff010a04084610020a04080010030a04080010030a0d084810b6ffffffffffffffff010a1608ffffffffffffffffff0110ffffffffffffffffff0110ffffffff0f0a04080010030a04080010020a04080010030a04084510020a04080010020a04080010020a0d080010b7ffffffffffffffff010a04084610020a04080010030a04080010030a0d084810b6ffffffffffffffff010a1608ffffffffffffffffff0110ffffffffffffffffff01ffffffffffffffffff01ffffffff'
    ),
]


# pylint: enable=line-too-long
class TestHandleBinaryData(unittest.TestCase):
    """Tests for binary log data handling in gonk_tools."""

    def setUp(self) -> None:
        # Show large diffs
        self.maxDiff = None  # pylint: disable=invalid-name

    def test_gonk_stream_write(self) -> None:
        """Test GonkLogStream.write handles protos and logs correctly.

        Interleaves an encoded log sample with a proto packet whose size
        varint spans two bytes, and checks the stream's buffers after each
        write.
        """
        log_data = _SAMPLE_ENCODED_LOGS[1]
        proto_data = _SAMPLE_LONG_PROTO_PACKETS[0]

        gl = GonkLogStream()
        gl.write(log_data)
        # No proto bytes yet.
        self.assertIsNone(gl.remaining_proto_bytes)

        gl.write(proto_data + log_data)
        # One copy of log_data should be in the log_buffer.
        self.assertEqual(log_data, gl.log_buffer.getbuffer())
        # The proto packet should be fully consumed.
        self.assertIsNone(gl.remaining_proto_bytes)
        # The remaining copy of log_data is left in the input_buffer.
        self.assertEqual(log_data, gl.input_buffer.getbuffer())

        # Write a third copy of log_data.
        gl.write(log_data)
        # Log buffer should now hold all three copies.
        self.assertEqual(log_data * 3, gl.log_buffer.getbuffer())
        # Input buffer now empty.
        self.assertEqual(b'', gl.input_buffer.getbuffer())

    def test_read_until(self) -> None:
        """Test BytearrayLoop.read_until.

        read_until returns the bytes preceding the searched-for sequence and
        leaves the sequence (and everything after it) in the buffer.
        """
        log_data = b''.join([
            _SAMPLE_ENCODED_LOGS[1],
            _SAMPLE_ENCODED_LOGS[2],
        ])
        proto_data = b''.join([
            _SAMPLE_PROTO_PACKETS[0],
            _SAMPLE_PROTO_PACKETS[1],
        ])

        with self.subTest(
                msg="Matching sub sequence exists; read until that."):
            merged_data = log_data + proto_data
            bl = BytearrayLoop(merged_data)
            self.assertEqual(log_data, bl.read_until(BIN_LOG_SYNC_START_BYTES))
            self.assertEqual(proto_data, bl.getbuffer())

        with self.subTest(
                msg="No matching sub sequence exists; read all data."):
            bl = BytearrayLoop(log_data)
            self.assertEqual(log_data, bl.read_until(BIN_LOG_SYNC_START_BYTES))
            # Remaining buffer should be empty
            self.assertEqual(b'', bl.getbuffer())

        with self.subTest(
                msg="Matching sub sequence at the start, read nothing."):
            bl = BytearrayLoop(proto_data)
            self.assertEqual(b'', bl.read_until(BIN_LOG_SYNC_START_BYTES))
            # Remaining buffer should be full
            self.assertEqual(proto_data, bl.getbuffer())

        with self.subTest(msg="Subsequence is longer than the buffer."):
            # Searching for a subsequence longer than the buffer should return
            # an empty result.
            too_short_buffer = BytearrayLoop(bytes.fromhex('abc123'))
            self.assertEqual(
                b'',
                too_short_buffer.read_until(bytes.fromhex('abc123456789')))

    def test_proto_packet_length(self) -> None:
        """Test decode_one_varint on remaining protobuf lengths.

        Covers packets whose size varint fits in one byte
        (_SAMPLE_PROTO_PACKETS) as well as ones that need two bytes
        (_SAMPLE_LONG_PROTO_PACKETS).
        """
        sync_len = len(BIN_LOG_SYNC_START_BYTES)
        all_packets = _SAMPLE_PROTO_PACKETS + _SAMPLE_LONG_PROTO_PACKETS
        for index, packet in enumerate(all_packets):
            with self.subTest(packet_index=index):
                bl = BytearrayLoop()
                bl.write(packet)
                self.assertEqual(packet, bl.getbuffer())

                # Read magic bytes fixed32 tag + 4 bytes.
                self.assertEqual(BIN_LOG_SYNC_START_BYTES, bl.peek(sync_len))
                self.assertEqual(BIN_LOG_SYNC_START_BYTES, bl.read(sync_len))

                # Read the one-byte tag of the length-delimited field.
                _tag = bl.read(1)

                # Decode the varint holding the remaining packet size, and
                # measure how many bytes the varint itself occupied.
                size_before_varint = len(bl.getbuffer())
                decoded_remaining_size = decode_one_varint(bl)
                varint_len = size_before_varint - len(bl.getbuffer())

                # Everything after the sync bytes, tag, and varint is payload.
                expected = len(packet) - sync_len - len(_tag) - varint_len
                self.assertEqual(decoded_remaining_size, expected)


if __name__ == '__main__':
    # Run the tests when this file is executed directly as a script.
    unittest.main()
