blob: cb69534a1e1f331984dfbee6b59640efe8f53b62 [file] [log] [blame]
#include "./rpc_fuzzing/rpc_sequence.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message.h"
#include "google/protobuf/util/message_differencer.h"
#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/random/bit_gen_ref.h"
#include "absl/random/random.h"
#include "./fuzztest/internal/logging.h"
namespace fuzztest::internal {
namespace {
std::vector<int> GetShuffledIndices(size_t size, absl::BitGenRef gen) {
std::vector<int> shuffled_indices(size);
std::iota(shuffled_indices.begin(), shuffled_indices.end(), 0);
std::shuffle(shuffled_indices.begin(), shuffled_indices.end(), gen);
return shuffled_indices;
}
std::vector<RpcNodeID> GetShuffledNodeIdsInGraph(const RpcDataFlowGraph& graph,
absl::BitGenRef gen) {
std::vector<RpcNodeID> shuffled_node_ids;
shuffled_node_ids.reserve(graph.NodeNum());
for (const auto& [node_id, unused_node] : graph.GetAllNodes()) {
shuffled_node_ids.push_back(node_id);
}
std::shuffle(shuffled_node_ids.begin(), shuffled_node_ids.end(), gen);
return shuffled_node_ids;
}
void RemapRpcNodeIdToSequenceIndex(
RpcSequence& sequence,
const absl::flat_hash_map<RpcNodeID, RpcNodeID>& id_map) {
for (RpcNode& node : sequence) {
for (RpcDataFlowEdge& edge : node.dependencies()) {
edge.from_node_id = id_map.at(edge.from_node_id);
}
}
}
void CollectTopologicalSortingOrder(const RpcDataFlowGraph& graph,
RpcNodeID current_id, absl::BitGenRef gen,
absl::flat_hash_set<RpcNodeID>& visited,
std::vector<RpcNodeID>& sorted_ids) {
if (visited.contains(current_id)) return;
visited.insert(current_id);
const RpcNode& node = graph.GetNode(current_id);
std::vector<RpcNodeID> shuffled_edge_indices =
GetShuffledIndices(node.dependencies().size(), gen);
for (int edge_index : shuffled_edge_indices) {
CollectTopologicalSortingOrder(graph,
node.dependencies()[edge_index].from_node_id,
gen, visited, sorted_ids);
}
sorted_ids.push_back(current_id);
}
std::vector<RpcNodeID> GetRandomTopologicalOrdering(
const RpcDataFlowGraph& graph) {
std::vector<RpcNodeID> sorted_node_ids;
sorted_node_ids.reserve(graph.NodeNum());
absl::flat_hash_set<RpcNodeID> visited;
absl::BitGen gen;
std::vector<RpcNodeID> shuffled_node_ids =
GetShuffledNodeIdsInGraph(graph, gen);
for (RpcNodeID node_id : shuffled_node_ids) {
CollectTopologicalSortingOrder(graph, node_id, gen, visited,
sorted_node_ids);
}
FUZZTEST_INTERNAL_CHECK(
sorted_node_ids.size() == graph.NodeNum(),
"Topological sort results in different number of nodes!");
return sorted_node_ids;
}
RpcSequence SequentializeGraphByOrder(
const RpcDataFlowGraph& graph, const std::vector<RpcNodeID>& node_order) {
RpcSequence result;
result.reserve(graph.NodeNum());
absl::flat_hash_map<RpcNodeID /* original node id */,
RpcNodeID /* sequence index */>
id_map;
for (size_t seq_idx = 0; seq_idx < node_order.size(); ++seq_idx) {
RpcNodeID original_node_id = node_order[seq_idx];
id_map[original_node_id] = seq_idx;
result.push_back(graph.GetNode(original_node_id));
}
RemapRpcNodeIdToSequenceIndex(result, id_map);
return result;
}
} // namespace
RpcNode::RpcNode(const RpcNode& other)
: method_(other.method_),
request_(absl::WrapUnique(other.request_->New())),
dependencies_(other.dependencies_) {
request_->CopyFrom(*other.request_);
}
RpcNode& RpcNode::operator=(const RpcNode& other) {
method_ = other.method_;
request_ = absl::WrapUnique(other.request_->New());
request_->CopyFrom(*other.request_);
dependencies_ = other.dependencies_;
return *this;
}
bool operator==(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs) {
return lhs.from_node_id == rhs.from_node_id &&
lhs.from_field == rhs.from_field && lhs.to_field == rhs.to_field;
}
bool operator!=(const RpcDataFlowEdge& lhs, const RpcDataFlowEdge& rhs) {
return !(lhs == rhs);
}
bool operator==(const RpcNode& lhs, const RpcNode& rhs) {
return lhs.method_->full_name() == rhs.method_->full_name() &&
google::protobuf::util::MessageDifferencer::Equals(*lhs.request_,
*rhs.request_) &&
lhs.dependencies_ == rhs.dependencies_;
}
bool operator!=(const RpcNode& lhs, const RpcNode& rhs) {
return !(lhs == rhs);
}
void RpcDataFlowGraph::AddNode(RpcNodeID id, RpcNode rpc_node) {
InvalidateNodeOrder();
rpc_nodes_.emplace(id, rpc_node);
}
void RpcDataFlowGraph::RemoveNode(RpcNodeID id) {
InvalidateNodeOrder();
rpc_nodes_.erase(id);
}
const RpcNode& RpcDataFlowGraph::GetNode(RpcNodeID id) const {
FUZZTEST_INTERNAL_CHECK(rpc_nodes_.contains(id), "Invalid id!");
return rpc_nodes_.at(id);
}
size_t RpcDataFlowGraph::NodeNum() const { return rpc_nodes_.size(); }
RpcDataFlowGraph RpcDataFlowGraph::FromSequence(const RpcSequence& sequence) {
RpcDataFlowGraph graph;
for (int node_idx = 0; node_idx < sequence.size(); ++node_idx) {
graph.AddNode(node_idx, sequence[node_idx]);
}
graph.node_order_ = std::vector<RpcNodeID>(sequence.size());
std::iota(graph.node_order_->begin(), graph.node_order_->end(), 0);
return graph;
}
RpcSequence RpcDataFlowGraph::GetSequence() const {
if (!node_order_.has_value()) {
node_order_ = GetRandomTopologicalOrdering(*this);
}
return SequentializeGraphByOrder(*this, *node_order_);
}
std::vector<RpcNodeID>& RpcDataFlowGraph::GetOrderedNodeIds() const {
if (!node_order_.has_value()) {
node_order_ = GetRandomTopologicalOrdering(*this);
}
return *node_order_;
}
bool operator==(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs) {
return lhs.GetSequence() == rhs.GetSequence();
}
bool operator!=(const RpcDataFlowGraph& lhs, const RpcDataFlowGraph& rhs) {
return !(lhs == rhs);
}
void RpcDataFlowGraph::RandomizeTopologicalOrdering() { InvalidateNodeOrder(); }
} // namespace fuzztest::internal