| # Copyright 2020 The Pigweed Authors |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); you may not |
| # use this file except in compliance with the License. You may obtain a copy of |
| # the License at |
| # |
| # https://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| # License for the specific language governing permissions and limitations under |
| # the License. |
| """Functions for working with pw_rpc packets.""" |
| |
| from typing import Optional |
| |
| from google.protobuf import message |
| from pw_status import Status |
| |
| from pw_rpc.internal import packet_pb2 |
| |
| |
def decode(data: bytes) -> packet_pb2.RpcPacket:
    """Parses serialized bytes into a new RpcPacket.

    Args:
      data: the serialized RpcPacket protobuf

    Returns:
      A newly created RpcPacket populated from data.
    """
    rpc_packet = packet_pb2.RpcPacket()
    # The message was just created, so there are no existing fields to merge
    # with; parsing into it yields the same result as merging.
    rpc_packet.ParseFromString(data)
    return rpc_packet
| |
| |
def decode_payload(packet, payload_type):
    """Decodes a packet's payload into a new instance of payload_type.

    Args:
      packet: object whose ``payload`` attribute holds the serialized payload
      payload_type: callable invoked with no arguments to create the object
          to decode into; the result must provide ``MergeFromString``

    Returns:
      The newly created payload_type instance after merging in the payload.
    """
    decoded = payload_type()
    serialized = packet.payload
    decoded.MergeFromString(serialized)
    return decoded
| |
| |
def _ids(rpc: tuple) -> tuple:
    """Converts each element of rpc to an ID.

    Elements that are already ints are kept as they are; for any other element
    its ``id`` attribute is used instead.

    Args:
      rpc: iterable of ints and/or objects with an ``id`` attribute

    Returns:
      A tuple with one ID per element of rpc, in the same order.
    """
    resolved = []
    for element in rpc:
        if isinstance(element, int):
            resolved.append(element)
        else:
            resolved.append(element.id)
    return tuple(resolved)
| |
| |
def encode_request(rpc: tuple, request: Optional[message.Message]) -> bytes:
    """Builds a serialized REQUEST packet for an RPC.

    Args:
      rpc: (channel, service, method) triple; each element is either an int ID
          or an object with an ``id`` attribute
      request: protobuf message to serialize as the payload, or None to send
          an empty payload

    Returns:
      The serialized RpcPacket as bytes.
    """
    channel_id, service_id, method_id = _ids(rpc)

    if request is None:
        serialized_request = b''
    else:
        serialized_request = request.SerializeToString()

    request_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.REQUEST,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
        payload=serialized_request,
    )
    return request_packet.SerializeToString()
| |
| |
def encode_response(rpc: tuple, response: message.Message) -> bytes:
    """Builds a serialized RESPONSE packet for an RPC.

    Args:
      rpc: (channel, service, method) triple; each element is either an int ID
          or an object with an ``id`` attribute
      response: protobuf message to serialize as the payload

    Returns:
      The serialized RpcPacket as bytes.
    """
    channel_id, service_id, method_id = _ids(rpc)

    response_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
        payload=response.SerializeToString(),
    )
    return response_packet.SerializeToString()
| |
| |
def encode_client_stream(rpc: tuple, request: message.Message) -> bytes:
    """Builds a serialized CLIENT_STREAM packet for an RPC.

    Args:
      rpc: (channel, service, method) triple; each element is either an int ID
          or an object with an ``id`` attribute
      request: protobuf message to serialize as the payload

    Returns:
      The serialized RpcPacket as bytes.
    """
    channel_id, service_id, method_id = _ids(rpc)

    stream_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_STREAM,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
        payload=request.SerializeToString(),
    )
    return stream_packet.SerializeToString()
| |
| |
def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
    """Builds a serialized CLIENT_ERROR packet in reply to another packet.

    The channel, service, and method IDs are copied from the given packet.

    Args:
      packet: the packet whose channel_id, service_id, and method_id are reused
      status: the status to report; its ``value`` is stored in the packet

    Returns:
      The serialized RpcPacket as bytes.
    """
    error_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        channel_id=packet.channel_id,
        service_id=packet.service_id,
        method_id=packet.method_id,
        status=status.value,
    )
    return error_packet.SerializeToString()
| |
| |
def encode_cancel(rpc: tuple) -> bytes:
    """Builds a serialized packet that cancels an RPC.

    The cancellation is encoded as a CLIENT_ERROR packet whose status is
    Status.CANCELLED.

    Args:
      rpc: (channel, service, method) triple; each element is either an int ID
          or an object with an ``id`` attribute

    Returns:
      The serialized RpcPacket as bytes.
    """
    channel_id, service_id, method_id = _ids(rpc)

    cancel_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        status=Status.CANCELLED.value,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
    )
    return cancel_packet.SerializeToString()
| |
| |
def encode_client_stream_end(rpc: tuple) -> bytes:
    """Builds a serialized CLIENT_STREAM_END packet for an RPC.

    The packet carries only the channel, service, and method IDs.

    Args:
      rpc: (channel, service, method) triple; each element is either an int ID
          or an object with an ``id`` attribute

    Returns:
      The serialized RpcPacket as bytes.
    """
    channel_id, service_id, method_id = _ids(rpc)

    end_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_STREAM_END,
        channel_id=channel_id,
        service_id=service_id,
        method_id=method_id,
    )
    return end_packet.SerializeToString()
| |
| |
def for_server(packet: packet_pb2.RpcPacket) -> bool:
    """Returns True if the packet's type value is even.

    NOTE(review): presumably the PacketType enum assigns even values to
    packets sent to the server and odd values to packets sent to the client;
    confirm against the PacketType definition.
    """
    remainder = packet.type % 2
    return remainder == 0