blob: 2feb3fd5c69ab1e795470232591a91811d2e0ce3 [file] [log] [blame]
# Copyright 2020 The Pigweed Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""Functions for working with pw_rpc packets."""
import dataclasses
from google.protobuf import message
from pw_status import Status
from pw_rpc.internal import packet_pb2
def decode(data: bytes) -> packet_pb2.RpcPacket:
    """Parses wire-format bytes into a new RpcPacket.

    Args:
        data: Serialized RpcPacket message.

    Returns:
        A freshly constructed RpcPacket populated from ``data``.

    Raises:
        google.protobuf.message.DecodeError: ``data`` is not a valid RpcPacket.
    """
    return packet_pb2.RpcPacket.FromString(data)
def decode_payload(packet, payload_type):
    """Deserializes the payload bytes of a packet into a new message.

    Args:
        packet: Packet whose ``payload`` attribute holds serialized bytes.
        payload_type: Message class to instantiate; it must support
            ``MergeFromString``.

    Returns:
        A new ``payload_type`` instance populated from ``packet.payload``.
    """
    decoded = payload_type()
    decoded.MergeFromString(packet.payload)
    return decoded
@dataclasses.dataclass(frozen=True)
class RpcIds:
    """The four integer IDs that together uniquely identify an RPC.

    Instances are immutable and compare and hash by field values, so they are
    usable as dictionary keys and set members.
    """

    channel_id: int  # ID of the channel the call uses.
    service_id: int  # ID of the service being called.
    method_id: int  # ID of the method being called.
    call_id: int  # ID of this particular call.
def encode_request(rpc: RpcIds, request: message.Message | None) -> bytes:
    """Serializes a REQUEST packet for the given call.

    Args:
        rpc: Channel, service, method and call IDs for the packet.
        request: Message to serialize as the payload; None yields an empty
            payload.

    Returns:
        The wire-format RpcPacket.
    """
    if request is None:
        payload = b''
    else:
        payload = request.SerializeToString()

    request_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.REQUEST,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
        payload=payload,
    )
    return request_packet.SerializeToString()
def encode_response(rpc: RpcIds, response: message.Message) -> bytes:
    """Serializes a RESPONSE packet carrying ``response`` as its payload.

    Args:
        rpc: Channel, service, method and call IDs for the packet.
        response: Message to serialize into the payload.

    Returns:
        The wire-format RpcPacket.
    """
    call_ids = dict(
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    )
    response_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.RESPONSE,
        payload=response.SerializeToString(),
        **call_ids,
    )
    return response_packet.SerializeToString()
def encode_client_stream(rpc: RpcIds, request: message.Message) -> bytes:
    """Serializes a CLIENT_STREAM packet carrying one streamed message.

    Args:
        rpc: Channel, service, method and call IDs for the packet.
        request: Message to serialize into the payload.

    Returns:
        The wire-format RpcPacket.
    """
    stream_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_STREAM,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    )
    stream_packet.payload = request.SerializeToString()
    return stream_packet.SerializeToString()
def encode_client_error(packet: packet_pb2.RpcPacket, status: Status) -> bytes:
    """Serializes a CLIENT_ERROR packet replying to a received packet.

    The error packet copies the channel, service, method and call IDs from
    ``packet`` and carries no payload.

    Args:
        packet: The packet the error refers to.
        status: Status to report; its ``value`` is written to the packet.

    Returns:
        The wire-format RpcPacket.
    """
    error_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        channel_id=packet.channel_id,
        service_id=packet.service_id,
        method_id=packet.method_id,
        call_id=packet.call_id,
        status=status.value,
    )
    return error_packet.SerializeToString()
def encode_cancel(rpc: RpcIds) -> bytes:
    """Serializes a packet that cancels the given call.

    Cancellation has no dedicated packet type here; it is expressed as a
    CLIENT_ERROR packet whose status is CANCELLED, with no payload.

    Args:
        rpc: Channel, service, method and call IDs of the call to cancel.

    Returns:
        The wire-format RpcPacket.
    """
    cancel_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_ERROR,
        status=Status.CANCELLED.value,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    )
    return cancel_packet.SerializeToString()
def encode_client_stream_end(rpc: RpcIds) -> bytes:
    """Serializes a packet marking the end of the client's stream.

    Despite this function's name, the packet type used is
    CLIENT_REQUEST_COMPLETION. The packet carries no payload.

    Args:
        rpc: Channel, service, method and call IDs for the packet.

    Returns:
        The wire-format RpcPacket.
    """
    completion_packet = packet_pb2.RpcPacket(
        type=packet_pb2.PacketType.CLIENT_REQUEST_COMPLETION,
        channel_id=rpc.channel_id,
        service_id=rpc.service_id,
        method_id=rpc.method_id,
        call_id=rpc.call_id,
    )
    return completion_packet.SerializeToString()
def for_server(packet: packet_pb2.RpcPacket) -> bool:
    """Returns True if ``packet`` is destined for an RPC server.

    Even-valued packet types are treated as server-bound and odd-valued ones
    as client-bound. This relies on that numbering convention in the
    PacketType enum.
    """
    return not (packet.type % 2)