#!/usr/bin/env python3
# Copyright 2021 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Tests using the callback client for pw_rpc."""

import unittest
from unittest import mock
from typing import Any, List, Optional, Tuple

from pw_protobuf_compiler import python_protos
from pw_status import Status

from pw_rpc import callback_client, client, packets
from pw_rpc.internal import packet_pb2

# Proto source compiled by every test's setUp(). PublicService declares one
# method of each RPC kind: unary, server streaming, client streaming, and
# bidirectional streaming.
TEST_PROTO_1 = """\
syntax = "proto3";

package pw.test1;

message SomeMessage {
  uint32 magic_number = 1;
}

message AnotherMessage {
  enum Result {
    FAILED = 0;
    FAILED_MISERABLY = 1;
    I_DONT_WANT_TO_TALK_ABOUT_IT = 2;
  }

  Result result = 1;
  string payload = 2;
}

service PublicService {
  rpc SomeUnary(SomeMessage) returns (AnotherMessage) {}
  rpc SomeServerStreaming(SomeMessage) returns (stream AnotherMessage) {}
  rpc SomeClientStreaming(stream SomeMessage) returns (AnotherMessage) {}
  rpc SomeBidiStreaming(stream SomeMessage) returns (stream AnotherMessage) {}
}
"""


def _message_bytes(msg) -> bytes:
    return msg if isinstance(msg, bytes) else msg.SerializeToString()


class _CallbackClientImplTestBase(unittest.TestCase):
    """Supports writing tests that require responses from an RPC server.

    Packets the client sends are decoded and recorded in ``self.requests``.
    Packets queued with the ``_enqueue_*`` helpers are fed into the client
    with ``process_packet()``. This happens once ``send_responses_after_packets``
    outgoing packets have been sent, or when ``_process_enqueued_packets()``
    is called directly.
    """
    def setUp(self) -> None:
        self._protos = python_protos.Library.from_strings(TEST_PROTO_1)
        self._request = self._protos.packages.pw.test1.SomeMessage

        # Channel 1 routes every outgoing packet to _handle_packet.
        self._client = client.Client.from_modules(
            callback_client.Impl(), [client.Channel(1, self._handle_packet)],
            self._protos.modules())
        self._service = self._client.channel(1).rpcs.pw.test1.PublicService

        # Decoded packets sent by the client, in send order.
        self.requests: List[packet_pb2.RpcPacket] = []
        # Serialized packets to deliver to the client. Each is paired with the
        # status process_packet() is expected to return for it.
        self._next_packets: List[Tuple[bytes, Status]] = []
        # Queued packets are delivered once this many packets have been sent.
        # This is a float so it can temporarily be set to infinity.
        self.send_responses_after_packets: float = 1

        # When set, _handle_packet raises this instead of recording packets.
        self.output_exception: Optional[Exception] = None

    def last_request(self) -> packet_pb2.RpcPacket:
        """Returns the most recent packet sent by the client."""
        assert self.requests
        return self.requests[-1]

    def _enqueue_response(self,
                          channel_id: int,
                          method=None,
                          status: Status = Status.OK,
                          payload=b'',
                          *,
                          ids: Optional[Tuple[int, int]] = None,
                          process_status=Status.OK) -> None:
        """Queues a RESPONSE packet.

        Exactly one of ``method`` or ``ids`` supplies the service and method
        IDs. ``ids`` is a ``(service_id, method_id)`` pair, which lets tests
        address services or methods that do not exist. ``payload`` may be raw
        bytes or a protobuf message. ``process_status`` is the status
        process_packet() is expected to return for this packet.
        """
        if method is not None:
            assert ids is None
            service_id, method_id = method.service.id, method.id
        else:
            assert ids is not None
            service_id, method_id = ids

        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.RESPONSE,
            channel_id=channel_id,
            service_id=service_id,
            method_id=method_id,
            status=status.value,
            payload=_message_bytes(payload)).SerializeToString(),
                                   process_status))

    def _enqueue_server_stream(self,
                               channel_id: int,
                               method,
                               response,
                               process_status=Status.OK) -> None:
        """Queues a SERVER_STREAM packet carrying one response message.

        ``response`` may be raw bytes or a protobuf message.
        """
        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_STREAM,
            channel_id=channel_id,
            service_id=method.service.id,
            method_id=method.id,
            payload=_message_bytes(response)).SerializeToString(),
                                   process_status))

    def _enqueue_error(self,
                       channel_id: int,
                       service,
                       method,
                       status: Status,
                       process_status=Status.OK) -> None:
        """Queues a SERVER_ERROR packet with the given status.

        ``service`` and ``method`` may each be a raw integer ID or an object
        with an ``id`` attribute.
        """
        self._next_packets.append((packet_pb2.RpcPacket(
            type=packet_pb2.PacketType.SERVER_ERROR,
            channel_id=channel_id,
            service_id=service if isinstance(service, int) else service.id,
            method_id=method if isinstance(method, int) else method.id,
            status=status.value).SerializeToString(), process_status))

    def _handle_packet(self, data: bytes) -> None:
        """Channel output that records sent packets and delivers responses."""
        if self.output_exception:
            raise self.output_exception  # pylint: disable=raising-bad-type

        self.requests.append(packets.decode(data))

        # Hold the queued packets until enough outgoing packets have been sent.
        if self.send_responses_after_packets > 1:
            self.send_responses_after_packets -= 1
            return

        self._process_enqueued_packets()

    def _process_enqueued_packets(self) -> None:
        """Delivers every queued packet to the client, then clears the queue.

        Checks that process_packet() returns the expected status for each
        packet. The queue and the send counter are reset even if processing
        raises, so later tests are not left with a stuck counter.
        """
        # Set send_responses_after_packets to infinity to prevent potential
        # infinite recursion when a packet causes another packet to send.
        send_after_count = self.send_responses_after_packets
        self.send_responses_after_packets = float('inf')

        try:
            for packet, status in self._next_packets:
                self.assertIs(status, self._client.process_packet(packet))
        finally:
            self._next_packets.clear()
            self.send_responses_after_packets = send_after_count

    def _sent_payload(self, message_type: type) -> Any:
        """Decodes the payload of the last sent packet as ``message_type``."""
        message = message_type()
        message.ParseFromString(self.last_request().payload)
        return message


class CallbackClientImplTest(_CallbackClientImplTestBase):
    """Tests the callback_client.Impl client implementation."""
    def test_callback_exceptions_suppressed(self) -> None:
        """An exception raised by a callback is logged, not propagated."""
        stub = self._service.SomeUnary

        self._enqueue_response(1, stub.method)
        exception_msg = 'YOU BROKE IT O-]-<'

        with self.assertLogs(callback_client.__package__, 'ERROR') as logs:
            stub.invoke(self._request(),
                        mock.Mock(side_effect=Exception(exception_msg)))

        self.assertIn(exception_msg, ''.join(logs.output))

        # Make sure we can still invoke the RPC.
        self._enqueue_response(1, stub.method, Status.UNKNOWN)
        status, _ = stub()
        self.assertIs(status, Status.UNKNOWN)

    def test_ignore_bad_packets_with_pending_rpc(self) -> None:
        """Misaddressed packets do not disturb a pending unary call."""
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_response(999, method, process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_response(1,
                               ids=(999, method.id),
                               process_status=Status.OK)
        # Bad method
        self._enqueue_response(1,
                               ids=(service_id, 999),
                               process_status=Status.OK)
        # For RPC not pending (is Status.OK because the packet is processed)
        self._enqueue_response(1,
                               ids=(service_id,
                                    self._service.SomeBidiStreaming.method.id),
                               process_status=Status.OK)

        self._enqueue_response(1, method, process_status=Status.OK)

        status, response = self._service.SomeUnary(magic_number=6)
        self.assertIs(Status.OK, status)
        self.assertEqual('', response.payload)

    def test_server_error_for_unknown_call_sends_no_errors(self) -> None:
        """SERVER_ERROR packets for unknown calls produce no outgoing packets."""
        method = self._service.SomeUnary.method
        service_id = method.service.id

        # Unknown channel
        self._enqueue_error(999,
                            service_id,
                            method,
                            Status.NOT_FOUND,
                            process_status=Status.NOT_FOUND)
        # Bad service
        self._enqueue_error(1, 999, method.id, Status.INVALID_ARGUMENT)
        # Bad method
        self._enqueue_error(1, service_id, 999, Status.INVALID_ARGUMENT)
        # For RPC not pending
        self._enqueue_error(1, service_id,
                            self._service.SomeBidiStreaming.method.id,
                            Status.NOT_FOUND)

        self._process_enqueued_packets()

        self.assertEqual(self.requests, [])

    def test_exception_if_payload_fails_to_decode(self) -> None:
        """An undecodable response payload fails the call with DATA_LOSS."""
        method = self._service.SomeUnary.method

        self._enqueue_response(1,
                               method,
                               Status.OK,
                               b'INVALID DATA!!!',
                               process_status=Status.OK)

        with self.assertRaises(callback_client.RpcError) as context:
            self._service.SomeUnary(magic_number=6)

        self.assertIs(context.exception.status, Status.DATA_LOSS)

    def test_rpc_help_contains_method_name(self) -> None:
        """help() mentions the method's full name."""
        rpc = self._service.SomeUnary
        self.assertIn(rpc.method.full_name, rpc.help())

    def test_default_timeouts_set_on_impl(self) -> None:
        """Impl stores its default unary and stream timeouts."""
        impl = callback_client.Impl(None, 1.5)

        self.assertIsNone(impl.default_unary_timeout_s)
        self.assertEqual(impl.default_stream_timeout_s, 1.5)

    def test_default_timeouts_set_for_all_rpcs(self) -> None:
        """Each RPC picks up the unary or stream default from the Impl.

        Unary and client-streaming RPCs use the unary default. Server- and
        bidirectional-streaming RPCs use the stream default.
        """
        rpc_client = client.Client.from_modules(callback_client.Impl(
            99, 100), [client.Channel(1, lambda *a, **b: None)],
                                                self._protos.modules())
        rpcs = rpc_client.channel(1).rpcs

        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeUnary.default_timeout_s, 99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeServerStreaming.default_timeout_s,
            100)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeClientStreaming.default_timeout_s,
            99)
        self.assertEqual(
            rpcs.pw.test1.PublicService.SomeBidiStreaming.default_timeout_s,
            100)

    def test_rpc_provides_request_type(self) -> None:
        """The method client exposes the method's request message class."""
        self.assertIs(self._service.SomeUnary.request,
                      self._service.SomeUnary.method.request_type)

    def test_rpc_provides_response_type(self) -> None:
        """The method client exposes the method's response message class."""
        self.assertIs(self._service.SomeUnary.response,
                      self._service.SomeUnary.method.response_type)


class UnaryTest(_CallbackClientImplTestBase):
    """Tests for invoking a unary RPC."""
    def setUp(self) -> None:
        super().setUp()
        # The unary method client under test and its method descriptor.
        self.rpc = self._service.SomeUnary
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call returns the server's status and response."""
        # Each iteration invokes the same RPC again.
        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            status, response = self._service.SomeUnary(
                self.method.request_type(magic_number=6))

            self.assertEqual(
                6,
                self._sent_payload(self.method.request_type).magic_number)

            self.assertIs(Status.ABORTED, status)
            self.assertEqual('0_o', response.payload)

    def test_nonblocking_call(self) -> None:
        """invoke() delivers the response, then the status, to callbacks."""
        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            # The same mock is passed for both callbacks, so it records the
            # response and the final status in order.
            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=5), callback,
                                   callback)

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED)
            ])

            self.assertEqual(
                5,
                self._sent_payload(self.method.request_type).magic_number)

    def test_open(self) -> None:
        """open() sends nothing, but later responses still reach callbacks."""
        # Any attempt to send a packet raises, so open() must not send one.
        self.output_exception = IOError('something went wrong sending!')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.ABORTED,
                                   self.method.response_type(payload='0_o'))

            callback = mock.Mock()
            call = self.rpc.open(self._request(magic_number=5), callback,
                                 callback)
            self.assertEqual(self.requests, [])

            # Nothing was sent, so deliver the queued response explicitly.
            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='0_o')),
                mock.call(call, Status.ABORTED)
            ])

    def test_blocking_server_error(self) -> None:
        """A SERVER_ERROR makes the blocking call raise RpcError."""
        for _ in range(3):
            self._enqueue_error(1, self.method.service, self.method,
                                Status.NOT_FOUND)

            with self.assertRaises(callback_client.RpcError) as context:
                self._service.SomeUnary(
                    self.method.request_type(magic_number=6))

            self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_cancel(self) -> None:
        """cancel() succeeds once, sends nothing, and fires no callback."""
        callback = mock.Mock()

        for _ in range(3):
            call = self._service.SomeUnary.invoke(
                self._request(magic_number=55), callback)

            self.assertGreater(len(self.requests), 0)
            self.requests.clear()

            self.assertTrue(call.cancel())
            self.assertFalse(call.cancel())  # Already cancelled, returns False

            # Unary RPCs do not send a cancel request to the server.
            self.assertFalse(self.requests)

        callback.assert_not_called()

    def test_nonblocking_with_request_args(self) -> None:
        """request_args fields are used to build the sent request."""
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138)

    def test_blocking_timeout_as_argument(self) -> None:
        """With no response queued, pw_rpc_timeout_s raises RpcTimeout."""
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary(pw_rpc_timeout_s=0.0001)

    def test_blocking_timeout_set_default(self) -> None:
        """With no response queued, default_timeout_s raises RpcTimeout."""
        self._service.SomeUnary.default_timeout_s = 0.0001

        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeUnary()

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        """Invoking a pending RPC again cancels the first call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_exception_in_callback(self) -> None:
        """A raising on_completed surfaces from wait() as a RuntimeError.

        The original exception is available as the RuntimeError's __cause__.
        """
        exception = ValueError('something went wrong!')

        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(on_completed=mock.Mock(side_effect=exception))

        with self.assertRaises(RuntimeError) as context:
            call.wait()

        self.assertEqual(context.exception.__cause__, exception)


class ServerStreamingTest(_CallbackClientImplTestBase):
    """Tests for server streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        # The server streaming method client under test and its descriptor.
        self.rpc = self._service.SomeServerStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call collects all streamed responses."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        # Each iteration invokes the same RPC again.
        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            self.assertEqual(
                [rep1, rep2],
                self._service.SomeServerStreaming(magic_number=4).responses)

            self.assertEqual(
                4,
                self._sent_payload(self.method.request_type).magic_number)

    def test_deprecated_packet_format(self) -> None:
        """Legacy RESPONSE + SERVER_STREAM_END streams are still understood."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            # The original packet format used RESPONSE packets for the server
            # stream and a SERVER_STREAM_END packet as the last packet. These
            # are converted to SERVER_STREAM packets followed by a RESPONSE.
            self._enqueue_response(1, self.method, payload=rep1)
            self._enqueue_response(1, self.method, payload=rep2)

            # No helper builds this legacy packet type, so queue it directly.
            self._next_packets.append((packet_pb2.RpcPacket(
                type=packet_pb2.PacketType.DEPRECATED_SERVER_STREAM_END,
                channel_id=1,
                service_id=self.method.service.id,
                method_id=self.method.id,
                status=Status.INVALID_ARGUMENT.value).SerializeToString(),
                                       Status.OK))

            status, replies = self._service.SomeServerStreaming(magic_number=4)
            self.assertEqual([rep1, rep2], replies)
            self.assertIs(status, Status.INVALID_ARGUMENT)

            self.assertEqual(
                4,
                self._sent_payload(self.method.request_type).magic_number)

    def test_nonblocking_call(self) -> None:
        """invoke() delivers each streamed response, then the status."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            # One mock serves as both callbacks, recording all calls in order.
            callback = mock.Mock()
            call = self.rpc.invoke(self._request(magic_number=3), callback,
                                   callback)

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

            self.assertEqual(
                3,
                self._sent_payload(self.method.request_type).magic_number)

    def test_open(self) -> None:
        """open() sends nothing, but later packets still reach callbacks."""
        # Any attempt to send a packet raises, so open() must not send one.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.ABORTED)

            callback = mock.Mock()
            call = self.rpc.open(self._request(magic_number=3), callback,
                                 callback)
            self.assertEqual(self.requests, [])

            # Nothing was sent, so deliver the queued packets explicitly.
            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.ABORTED),
            ])

    def test_nonblocking_cancel(self) -> None:
        """cancel() sends a CANCEL packet, and the RPC can be invoked again."""
        resp = self.rpc.method.response_type(payload='!!!')
        self._enqueue_server_stream(1, self.rpc.method, resp)

        callback = mock.Mock()
        call = self.rpc.invoke(self._request(magic_number=3), callback)
        callback.assert_called_once_with(
            call, self.rpc.method.response_type(payload='!!!'))

        callback.reset_mock()

        call.cancel()

        self.assertEqual(self.last_request().type,
                         packet_pb2.PacketType.CANCEL)

        # Ensure the RPC can be called after being cancelled.
        self._enqueue_server_stream(1, self.method, resp)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke(self._request(magic_number=3), callback,
                               callback)

        callback.assert_has_calls([
            mock.call(call, self.method.response_type(payload='!!!')),
            mock.call(call, Status.OK),
        ])

    def test_nonblocking_with_request_args(self) -> None:
        """request_args fields are used to build the sent request."""
        self.rpc.invoke(request_args=dict(magic_number=1138))
        self.assertEqual(
            self._sent_payload(self.rpc.request).magic_number, 1138)

    def test_blocking_timeout(self) -> None:
        """With nothing queued, pw_rpc_timeout_s raises RpcTimeout."""
        with self.assertRaises(callback_client.RpcTimeout):
            self._service.SomeServerStreaming(pw_rpc_timeout_s=0.0001)

    def test_nonblocking_iteration_timeout(self) -> None:
        """Iterating a call with nothing queued raises RpcTimeout."""
        call = self._service.SomeServerStreaming.invoke(timeout_s=0.0001)
        with self.assertRaises(callback_client.RpcTimeout):
            for _ in call:
                pass

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        """Invoking a pending RPC again cancels the first call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())

    def test_nonblocking_iterate_over_count(self) -> None:
        """get_responses(count=...) and iter() consume received responses.

        Four responses are queued and the RPC never completes. The three
        reads below consume 1 + 1 + 2 of them.
        """
        reply = self.method.response_type(payload='!?')

        for _ in range(4):
            self._enqueue_server_stream(1, self.method, reply)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses(count=1)), [reply])
        self.assertEqual(next(iter(call)), reply)
        self.assertEqual(list(call.get_responses(count=2)), [reply, reply])

    def test_nonblocking_iterate_after_completed_doesnt_block(self) -> None:
        """After completion and a full read, further iteration yields nothing."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.OK)

        call = self.rpc.invoke()

        self.assertEqual(list(call.get_responses()), [reply])
        self.assertEqual(list(call.get_responses()), [])
        self.assertEqual(list(call), [])


class ClientStreamingTest(_CallbackClientImplTestBase):
    """Tests for client streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        # The client streaming method client under test and its descriptor.
        self.rpc = self._service.SomeClientStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call sends all requests and returns the response."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        response = self.method.response_type(payload='yo')
        self._enqueue_response(1, self.method, Status.OK, response)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.OK)
        self.assertEqual(results.response, response)

    def test_blocking_server_error(self) -> None:
        """A SERVER_ERROR makes the blocking call raise RpcError."""
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets keeps its default of 1, so the error is
        # delivered in response to the first packet the client sends.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a successful client streaming RPC ended by the server."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=31)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                31,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.send(magic_number=32)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                32,
                self._sent_payload(self.method.request_type).magic_number)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_open(self) -> None:
        """open() sends nothing, but a later response still reaches callbacks."""
        # Any attempt to send a packet raises, so open() must not send one.
        self.output_exception = IOError('something went wrong sending!')
        payload = self.method.response_type(payload='-_-')

        for _ in range(3):
            self._enqueue_response(1, self.method, Status.OK, payload)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            # Nothing was sent, so deliver the queued response explicitly.
            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, payload),
                mock.call(call, Status.OK),
            ])

    def test_nonblocking_finish(self) -> None:
        """Tests a client streaming RPC ended by the client."""
        payload_1 = self.method.response_type(payload='-_-')

        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            self.assertFalse(stream.completed())

            stream.send(magic_number=37)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                37,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())

            # Enqueue the server response to be sent after the next message.
            self._enqueue_response(1, self.method, Status.OK, payload_1)

            stream.finish_and_wait()
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM_END,
                          self.last_request().type)

            self.assertTrue(stream.completed())
            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)
            self.assertEqual(payload_1, stream.response)

    def test_nonblocking_cancel(self) -> None:
        """cancel() sends CANCEL once and completes the call as CANCELLED."""
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()
            stream.send(magic_number=37)

            self.assertTrue(stream.cancel())
            self.assertIs(packet_pb2.PacketType.CANCEL,
                          self.last_request().type)
            self.assertFalse(stream.cancel())

            self.assertTrue(stream.completed())
            self.assertIs(stream.error, Status.CANCELLED)

    def test_nonblocking_server_error(self) -> None:
        """A SERVER_ERROR during streaming makes finish_and_wait() raise."""
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # The error is delivered in response to the send() below.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)
            stream.send(magic_number=2**32 - 1)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """A SERVER_ERROR after the stream ends makes finish_and_wait() raise."""
        for _ in range(3):
            stream = self._service.SomeClientStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """send() on a cancelled call raises RpcError with CANCELLED."""
        call = self._service.SomeClientStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            call.send(payload='hello')

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """Repeated finish_and_wait() calls return the same result."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_response(1, self.method, Status.UNAVAILABLE, reply)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.response, reply)

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """After a server error, every finish_and_wait() raises RpcError."""
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertIsNone(call.response)

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        """Invoking a pending RPC again cancels the first call."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


class BidirectionalStreamingTest(_CallbackClientImplTestBase):
    """Tests for bidirectional streaming RPCs."""
    def setUp(self) -> None:
        super().setUp()
        self.rpc = self._service.SomeBidiStreaming
        self.method = self.rpc.method

    def test_blocking_call(self) -> None:
        """A blocking call sends every request and returns the final status."""
        requests = [
            self.method.request_type(magic_number=123),
            self.method.request_type(magic_number=456),
        ]

        # Send after len(requests) and the client stream end packet.
        self.send_responses_after_packets = 3
        self._enqueue_response(1, self.method, Status.NOT_FOUND)

        results = self.rpc(requests)
        self.assertIs(results.status, Status.NOT_FOUND)
        self.assertFalse(results.responses)

    def test_blocking_server_error(self) -> None:
        """A server error makes the blocking call raise RpcError."""
        requests = [self.method.request_type(magic_number=123)]

        # send_responses_after_packets is not overridden here, so the error is
        # released according to the base test class's default setting.
        self._enqueue_error(1, self.method.service, self.method,
                            Status.NOT_FOUND)

        with self.assertRaises(callback_client.RpcError) as context:
            self.rpc(requests)

        self.assertIs(context.exception.status, Status.NOT_FOUND)

    def test_nonblocking_call(self) -> None:
        """Tests a bidirectional streaming RPC ended by the server."""
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            stream.send(magic_number=55)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                55,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([], responses)

            # The enqueued server packets are expected to be delivered by the
            # time the next send() returns.
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)

            stream.send(magic_number=66)
            self.assertIs(packet_pb2.PacketType.CLIENT_STREAM,
                          self.last_request().type)
            self.assertEqual(
                66,
                self._sent_payload(self.method.request_type).magic_number)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self._enqueue_response(1, self.method, Status.OK)

            stream.send(magic_number=77)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1, rep2], responses)

            self.assertIs(Status.OK, stream.status)
            self.assertIsNone(stream.error)

    def test_open(self) -> None:
        """open() sends no packets but still handles server packets."""
        # NOTE(review): presumably makes any attempted packet send raise, so
        # the test would fail if open() tried to send a request.
        self.output_exception = IOError('something went wrong sending!')
        rep1 = self.method.response_type(payload='!!!')
        rep2 = self.method.response_type(payload='?')

        for _ in range(3):
            self._enqueue_server_stream(1, self.method, rep1)
            self._enqueue_server_stream(1, self.method, rep2)
            self._enqueue_response(1, self.method, Status.OK)

            callback = mock.Mock()
            call = self.rpc.open(callback, callback, callback)
            self.assertEqual(self.requests, [])

            self._process_enqueued_packets()

            callback.assert_has_calls([
                mock.call(call, self.method.response_type(payload='!!!')),
                mock.call(call, self.method.response_type(payload='?')),
                mock.call(call, Status.OK),
            ])

    @mock.patch('pw_rpc.callback_client.call.Call._default_response')
    def test_nonblocking(self, callback) -> None:
        """A streamed response is passed to the default response callback."""
        reply = self.method.response_type(payload='This is the payload!')
        self._enqueue_server_stream(1, self.method, reply)

        self._service.SomeBidiStreaming.invoke()

        callback.assert_called_once_with(mock.ANY, reply)

    def test_nonblocking_server_error(self) -> None:
        """A server error completes the stream and is raised on finish."""
        rep1 = self.method.response_type(payload='!!!')

        for _ in range(3):
            responses: list = []
            stream = self._service.SomeBidiStreaming.invoke(
                lambda _, res, responses=responses: responses.append(res))
            self.assertFalse(stream.completed())

            self._enqueue_server_stream(1, self.method, rep1)

            stream.send(magic_number=55)
            self.assertFalse(stream.completed())
            self.assertEqual([rep1], responses)

            self._enqueue_error(1, self.method.service, self.method,
                                Status.OUT_OF_RANGE)

            stream.send(magic_number=99999)
            self.assertTrue(stream.completed())
            self.assertEqual([rep1], responses)

            # An error-terminated call has no status, only an error.
            self.assertIsNone(stream.status)
            self.assertIs(Status.OUT_OF_RANGE, stream.error)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()
            self.assertIs(context.exception.status, Status.OUT_OF_RANGE)

    def test_nonblocking_server_error_after_stream_end(self) -> None:
        """An error sent after the client stream ends is raised on finish."""
        for _ in range(3):
            stream = self._service.SomeBidiStreaming.invoke()

            # Error will be sent in response to the CLIENT_STREAM_END packet.
            self._enqueue_error(1, self.method.service, self.method,
                                Status.INVALID_ARGUMENT)

            with self.assertRaises(callback_client.RpcError) as context:
                stream.finish_and_wait()

            self.assertIs(context.exception.status, Status.INVALID_ARGUMENT)

    def test_nonblocking_send_after_cancelled(self) -> None:
        """send() on a cancelled call raises RpcError with CANCELLED status."""
        call = self._service.SomeBidiStreaming.invoke()
        self.assertTrue(call.cancel())

        with self.assertRaises(callback_client.RpcError) as context:
            # Use a field that exists on the request type (SomeMessage) so the
            # only reason send() can fail is that the call was cancelled, not
            # an invalid request field.
            call.send(magic_number=55)

        self.assertIs(context.exception.status, Status.CANCELLED)

    def test_nonblocking_finish_after_completed(self) -> None:
        """finish_and_wait() on a completed call keeps returning one result."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_response(1, self.method, Status.UNAVAILABLE)

        call = self.rpc.invoke()
        result = call.finish_and_wait()
        self.assertEqual(result.responses, [reply])

        self.assertEqual(result, call.finish_and_wait())
        self.assertEqual(result, call.finish_and_wait())

    def test_nonblocking_finish_after_error(self) -> None:
        """finish_and_wait() re-raises the error and keeps prior responses."""
        reply = self.method.response_type(payload='!?')
        self._enqueue_server_stream(1, self.method, reply)
        self._enqueue_error(1, self.method.service, self.method,
                            Status.UNAVAILABLE)

        call = self.rpc.invoke()

        for _ in range(3):
            with self.assertRaises(callback_client.RpcError) as context:
                call.finish_and_wait()

            self.assertIs(context.exception.status, Status.UNAVAILABLE)
            self.assertIs(call.error, Status.UNAVAILABLE)
            self.assertEqual(call.responses, [reply])

    def test_nonblocking_duplicate_calls_first_is_cancelled(self) -> None:
        """Invoking the same RPC again cancels the call already pending."""
        first_call = self.rpc.invoke()
        self.assertFalse(first_call.completed())

        second_call = self.rpc.invoke()

        self.assertIs(first_call.error, Status.CANCELLED)
        self.assertFalse(second_call.completed())


if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main()
