blob: e77db5fedacea4e35c96948d7438d8be1c62b539 [file] [log] [blame]
// Copyright 2020 The Pigweed Authors
//
// Licensed under the Apache License, Version 2.0 (the "License"); you may not
// use this file except in compliance with the License. You may obtain a copy of
// the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
// License for the specific language governing permissions and limitations under
// the License.
#pragma once
#include <cstddef>
#include <cstdint>
#include <new>
#include <type_traits>
#include <utility>

#include "pw_bytes/span.h"
#include "pw_function/function.h"
#include "pw_rpc/internal/base_client_call.h"
#include "pw_rpc/internal/method_type.h"
#include "pw_rpc/internal/nanopb_common.h"
#include "pw_status/status.h"
namespace pw::rpc {
// Interface for receiving the result of a unary RPC.
//
// Implementations must override ReceivedResponse(); RpcError() does nothing
// unless overridden.
template <class Response>
class UnaryResponseHandler {
 public:
  virtual ~UnaryResponseHandler() = default;

  // Invoked when the server's response arrives, with the status returned by
  // the method and the deserialized response struct.
  virtual void ReceivedResponse(Status status, const Response& response) = 0;

  // Invoked when the RPC client or server encounters an internal error.
  // The default implementation ignores the error.
  virtual void RpcError(Status /* status */) {}
};
// Interface for receiving the results of a server streaming RPC.
//
// ReceivedResponse() and Complete() must be overridden; RpcError() does
// nothing unless overridden.
template <class Response>
class ServerStreamingResponseHandler {
 public:
  virtual ~ServerStreamingResponseHandler() = default;

  // Invoked for each response the server sends, with the deserialized
  // response struct.
  virtual void ReceivedResponse(const Response& response) = 0;

  // Invoked when the server ends the stream, with the overall RPC status.
  virtual void Complete(Status status) = 0;

  // Invoked when the RPC client or server encounters an internal error.
  // The default implementation ignores the error.
  virtual void RpcError(Status /* status */) {}
};
namespace internal {
// Nanopb base class shared by every NanopbClientCall instantiation. It is not
// templated on the message types; it owns the NanopbMethodSerde that is built
// from the request and response message descriptors.
class BaseNanopbClientCall : public BaseClientCall {
 public:
  // Sends a request whose payload is encoded from `request_struct`.
  // NOTE(review): defined out of line; presumably `request_struct` must point
  // to the method's nanopb request struct — confirm against the definition.
  Status SendRequest(const void* request_struct);

 protected:
  // Binds the call to a channel/service/method, registers `handler` with the
  // base, and records the nanopb descriptors for both message types.
  constexpr BaseNanopbClientCall(rpc::Channel* channel,
                                 uint32_t service_id,
                                 uint32_t method_id,
                                 ResponseHandler handler,
                                 NanopbMessageDescriptor request_fields,
                                 NanopbMessageDescriptor response_fields)
      : BaseClientCall(channel, service_id, method_id, handler),
        serde_(request_fields, response_fields) {}

  // Value-initializes the base and uses null message descriptors.
  constexpr BaseNanopbClientCall()
      : BaseClientCall(), serde_(nullptr, nullptr) {}

  // Move-only: moves are defaulted, copies are deleted.
  BaseNanopbClientCall(BaseNanopbClientCall&&) = default;
  BaseNanopbClientCall& operator=(BaseNanopbClientCall&&) = default;
  BaseNanopbClientCall(const BaseNanopbClientCall&) = delete;
  BaseNanopbClientCall& operator=(const BaseNanopbClientCall&) = delete;

  // Encoder/decoder for this method's request and response messages.
  constexpr const NanopbMethodSerde& serde() const { return serde_; }

 private:
  NanopbMethodSerde serde_;
};
struct ErrorCallbacks {
ErrorCallbacks() = default;
ErrorCallbacks(Function<void(Status)> error) : rpc_error(std::move(error)) {}
void InvokeRpcError(Status status) {
if (rpc_error != nullptr) {
rpc_error(status);
}
}
Function<void(Status)> rpc_error;
};
// Callback set for a unary RPC: a single response handler plus the inherited
// error handler. kType tells NanopbClientCall how to dispatch packets.
template <typename ResponseType>
struct UnaryCallbacks : ErrorCallbacks {
  using Response = ResponseType;
  static constexpr MethodType kType = MethodType::kUnary;

  UnaryCallbacks() = default;
  UnaryCallbacks(Function<void(const Response&, Status)> on_response,
                 Function<void(Status)> on_error)
      : ErrorCallbacks(std::move(on_error)),
        unary_response(std::move(on_response)) {}

  // Receives the decoded response together with its status; may be empty.
  Function<void(const Response&, Status)> unary_response;
};
// Callback set for a server streaming RPC: a per-message handler, a handler
// for the end of the stream, and the inherited error handler. kType tells
// NanopbClientCall how to dispatch packets.
template <typename ResponseType>
struct ServerStreamingCallbacks : ErrorCallbacks {
  using Response = ResponseType;
  static constexpr MethodType kType = MethodType::kServerStreaming;

  ServerStreamingCallbacks() = default;
  ServerStreamingCallbacks(Function<void(const Response&)> on_response,
                           Function<void(Status)> on_end,
                           Function<void(Status)> on_error)
      : ErrorCallbacks(std::move(on_error)),
        stream_response(std::move(on_response)),
        stream_end(std::move(on_end)) {}

  // Receives each decoded message of the stream; may be empty.
  Function<void(const Response&)> stream_response;

  // Receives the final status when the stream ends; may be empty.
  Function<void(Status)> stream_end;
};
} // namespace internal
// Client-side state for one nanopb RPC call.
//
// `Callbacks` must provide a `Response` type, a `kType` MethodType, an
// InvokeRpcError(Status) member and the callback members used by the matching
// Invoke*Callback() below (internal::UnaryCallbacks and
// internal::ServerStreamingCallbacks fit this shape).
template <typename Callbacks>
class NanopbClientCall : public internal::BaseNanopbClientCall {
 public:
  // Creates a call on `channel` for the given service and method IDs.
  // Incoming packets are dispatched to `callbacks`; the descriptors are used
  // to encode requests and decode responses.
  constexpr NanopbClientCall(Channel* channel,
                             uint32_t service_id,
                             uint32_t method_id,
                             Callbacks callbacks,
                             internal::NanopbMessageDescriptor request_fields,
                             internal::NanopbMessageDescriptor response_fields)
      : BaseNanopbClientCall(channel,
                             service_id,
                             method_id,
                             &ResponseHandler,
                             request_fields,
                             response_fields),
        callbacks_(std::move(callbacks)) {}

  // Creates a call with a default-constructed base and empty callbacks.
  // callbacks_ is value-initialized directly rather than move-constructed
  // from a temporary.
  constexpr NanopbClientCall() : BaseNanopbClientCall(), callbacks_{} {}

  NanopbClientCall(const NanopbClientCall&) = delete;
  NanopbClientCall& operator=(const NanopbClientCall&) = delete;

  NanopbClientCall(NanopbClientCall&&) = default;
  NanopbClientCall& operator=(NanopbClientCall&&) = default;

 private:
  using Response = typename Callbacks::Response;

  // Buffer into which the nanopb struct is decoded. Its contents are unknown,
  // so it is aligned to maximum alignment to accommodate any type.
  using ResponseBuffer =
      std::aligned_storage_t<sizeof(Response), alignof(std::max_align_t)>;

  friend class Client;

  // Handler registered with BaseClientCall; recovers the concrete call type
  // and forwards the packet to it.
  static void ResponseHandler(internal::BaseClientCall& call,
                              const internal::Packet& packet) {
    static_cast<NanopbClientCall<Callbacks>&>(call).HandleResponse(packet);
  }

  // Dispatches a packet according to the method type of Callbacks. Packets
  // for any other method type are ignored.
  void HandleResponse(const internal::Packet& packet) {
    if constexpr (Callbacks::kType == internal::MethodType::kUnary) {
      InvokeUnaryCallback(packet);
    } else if constexpr (Callbacks::kType ==
                         internal::MethodType::kServerStreaming) {
      InvokeServerStreamingCallback(packet);
    }
  }

  // Unary calls: SERVER_ERROR packets go to the error callback. Otherwise the
  // payload is decoded and passed, with the packet status, to unary_response.
  // A decode failure is reported as Status::DataLoss().
  void InvokeUnaryCallback(const internal::Packet& packet) {
    if (packet.type() == internal::PacketType::SERVER_ERROR) {
      callbacks_.InvokeRpcError(packet.status());
      return;
    }

    // No response callback means nobody consumes the payload. That is not a
    // data loss, so skip decoding instead of reporting an error.
    if (!callbacks_.unary_response) {
      return;
    }

    ResponseBuffer response_struct{};
    if (!serde().DecodeResponse(&response_struct, packet.payload())) {
      callbacks_.InvokeRpcError(Status::DataLoss());
      return;
    }

    // DecodeResponse filled the buffer; launder yields a usable Response
    // pointer into the raw storage.
    callbacks_.unary_response(
        *std::launder(reinterpret_cast<Response*>(&response_struct)),
        packet.status());
  }

  // Server streaming calls: SERVER_ERROR packets go to the error callback,
  // RESPONSE packets end the stream via stream_end, and any other packet is
  // decoded and passed to stream_response. A decode failure is reported as
  // Status::DataLoss().
  void InvokeServerStreamingCallback(const internal::Packet& packet) {
    if (packet.type() == internal::PacketType::SERVER_ERROR) {
      callbacks_.InvokeRpcError(packet.status());
      return;
    }

    if (packet.type() == internal::PacketType::RESPONSE) {
      // The server sent the last packet; call the stream end callback.
      if (callbacks_.stream_end) {
        callbacks_.stream_end(packet.status());
      }
      return;
    }

    // Remaining packets are treated as SERVER_STREAM messages. Without a
    // stream_response callback there is nothing to deliver and no error.
    if (!callbacks_.stream_response) {
      return;
    }

    ResponseBuffer response_struct{};
    if (!serde().DecodeResponse(&response_struct, packet.payload())) {
      callbacks_.InvokeRpcError(Status::DataLoss());
      return;
    }

    callbacks_.stream_response(
        *std::launder(reinterpret_cast<Response*>(&response_struct)));
  }

  Callbacks callbacks_;
};
} // namespace pw::rpc