// blob: 13e7df1a256253365f4ecea0fe7fc62a6e20ed9b [file] [edit]
// Copyright 2022 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_REGEXP_DFA_H_
#define FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_REGEXP_DFA_H_
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <vector>
#include "absl/random/bit_gen_ref.h"
#include "absl/random/discrete_distribution.h"
#include "absl/random/distributions.h"
#include "absl/strings/string_view.h"
#include "./fuzztest/internal/logging.h"
#include "re2/re2.h"
namespace fuzztest::internal {
// Represents the deterministic finite automaton (DFA) of a regular expression.
class RegexpDFA {
 public:
  // A node of the DFA together with its outgoing transitions.
  struct State {
    // A state without outgoing transitions is an end state.
    bool is_end_state() const { return next.empty(); }

    // An outgoing edge: the characters to match and the ID (index into
    // `states_`) of the state the edge leads to.
    // The special character `256` (kEndOfString) indicates the end of the input
    // string.
    struct StateTransition {
      std::vector<std::int16_t> chars_to_match;
      int next_state_id;
    };
    std::vector<StateTransition> next;

    // The weight represents the probability of an edge being chosen during a
    // random walk. The larger the weight, the higher the chance the edge gets
    // picked. An edge transitioning to a node that is more likely to reach the
    // end state has a larger weight than one that doesn't.
    absl::discrete_distribution<int> edge_weight_distribution;
  };

  // `edge_index` is the index for the outgoing edge in State::next.
  struct Edge {
    int from_state_id;
    int edge_index;
  };

  static RegexpDFA Create(absl::string_view regexp);

  // Performs a weighted random walk starting at state 0 and appends the
  // characters of every traversed edge to the result. The walk stops after
  // taking an edge whose last character is `kEndOfString`; that marker is
  // removed from the returned string.
  std::string GenerateString(absl::BitGenRef prng) {
    std::string result;
    State* state = &states_[0];
    while (true) {
      FUZZTEST_INTERNAL_CHECK(!state->next.empty(), "Empty next state!");
      // Pick a random next state by weight.
      const int rand_index = state->edge_weight_distribution(prng);
      const auto& [next_string, next_id] = state->next[rand_index];
      // `back()` below is undefined on an empty vector, so fail loudly instead.
      FUZZTEST_INTERNAL_CHECK(!next_string.empty(),
                              "Empty transition characters!");
      state = &states_[next_id];
      result.insert(result.end(), next_string.begin(), next_string.end());
      if (next_string.back() == kEndOfString) {
        FUZZTEST_INTERNAL_CHECK(state->is_end_state(),
                                "EOS should lead to end state!");
        // Drop the end-of-string marker that was appended above.
        result.pop_back();
        break;
      }
    }
    return result;
  }

  // Randomly walk from the state of `from_state_id` to any state of
  // `to_state_ids` or an end state. `to_state_ids` is indexed by state ID; a
  // state counts as a target when its entry holds a value. Every element of
  // the returned path records a visited state and the index of the edge taken
  // from it; the final element holds the state where the walk stopped and has
  // `edge_index == -1`.
  std::vector<Edge> FindPath(
      absl::BitGenRef prng, int from_state_id,
      const std::vector<std::optional<int>>& to_state_ids) {
    std::vector<Edge> path;
    int cur_state_id = from_state_id;
    while (true) {
      State& cur_state = states_[cur_state_id];
      FUZZTEST_INTERNAL_CHECK(!cur_state.next.empty(), "Empty next state!");
      // Pick a random next state by weight.
      const int offset = cur_state.edge_weight_distribution(prng);
      path.push_back({cur_state_id, offset});
      cur_state_id = cur_state.next[offset].next_state_id;
      // Reached an end state or found a state in the original path?
      if (states_[cur_state_id].is_end_state() ||
          to_state_ids[cur_state_id].has_value()) {
        path.push_back({cur_state_id, -1});
        break;
      }
    }
    return path;
  }

  // Randomly DFS from the state of `from_state_id` to the state of
  // `to_state_id`, looking for a path of at least one and at most `length`
  // edges. If paths of several lengths exist, one of the lengths is picked
  // uniformly at random; for each (state, length) pair the incoming edge is
  // picked by reservoir sampling. Since this (nearly) fully explores all the
  // paths, there is no big difference in the efficiency. And we prefer DFS to
  // BFS for better readability.
  //
  // `length` must be positive, and a path must exist; both are enforced with
  // internal checks.
  std::vector<Edge> FindPathWithinLengthDFS(absl::BitGenRef prng,
                                            int from_state_id, int to_state_id,
                                            int length) {
    // The tables below are indexed by path lengths in [1, length]; a
    // non-positive `length` would index them out of bounds.
    FUZZTEST_INTERNAL_CHECK(length > 0, "Path length must be positive!");

    // Each state maintains an edge and a counter for each possible length. The
    // edge is the last edge in the path from `from_state` to the state and can
    // be used to reconstruct the path. And the counter is the number of paths
    // to the state, which can be used for reservoir sampling.
    struct LastEdgeAndCounter {
      std::optional<Edge> edge;
      int counter = 0;
    };
    std::vector<std::vector<LastEdgeAndCounter>> last_edges_and_counters(
        states_.size(), std::vector<LastEdgeAndCounter>(length + 1));

    // Randomness for DFS. Instead of always starting to explore from edge index
    // 0, we start with a different random offset for each state.
    std::vector<int> rand_edge_offsets(states_.size(), 0);
    for (int& offset : rand_edge_offsets) {
      offset = absl::Uniform<int>(prng, 0, 256);
    }

    // Each stack entry is a state on the current DFS path plus the number of
    // its outgoing edges that have already been explored.
    std::vector<Edge> stack{Edge{from_state_id, 0}};
    while (!stack.empty()) {
      const auto [current_state_id, edge_index] = stack.back();
      const std::vector<State::StateTransition>& edges =
          states_[current_state_id].next;
      const int edge_count = static_cast<int>(edges.size());
      if (edge_index == edge_count) {
        // All outgoing edges of this state have been explored.
        stack.pop_back();
        continue;
      }
      ++stack.back().edge_index;

      // The stack depth equals the length of the path ending with the edge
      // explored in this iteration.
      const int current_path_length = static_cast<int>(stack.size());
      const int real_edge_index =
          (edge_index + rand_edge_offsets[current_state_id]) % edge_count;
      const int next_state_id = edges[real_edge_index].next_state_id;

      LastEdgeAndCounter& slot =
          last_edges_and_counters[next_state_id][current_path_length];
      const int n_path_of_current_length = ++slot.counter;
      // Reservoir Sampling.
      if (absl::Bernoulli(prng, 1.0 / n_path_of_current_length)) {
        slot.edge = Edge{current_state_id, real_edge_index};
      }
      // Only descend the first time a state is reached with this path length,
      // and never beyond the length limit.
      if (n_path_of_current_length == 1 && current_path_length != length) {
        stack.push_back(Edge{next_state_id, 0});
      }
    }

    // Collect every path length with which `to_state_id` was reached.
    std::vector<int> candidate_lens;
    for (int len = 1; len <= length; ++len) {
      if (last_edges_and_counters[to_state_id][len].edge.has_value()) {
        candidate_lens.push_back(len);
      }
    }
    FUZZTEST_INTERNAL_CHECK(!candidate_lens.empty(), "Cannot find a path!");

    const int chosen_len = candidate_lens[absl::Uniform<int>(
        prng, 0, static_cast<int>(candidate_lens.size()))];

    // Walk the recorded last edges backwards from `to_state_id`.
    std::vector<Edge> result;
    result.reserve(chosen_len);
    int state_id = to_state_id;
    for (int len = chosen_len; len > 0; --len) {
      const Edge& last_edge = *last_edges_and_counters[state_id][len].edge;
      result.push_back(last_edge);
      state_id = last_edge.from_state_id;
    }
    FUZZTEST_INTERNAL_CHECK(state_id == from_state_id,
                            "Cannot find a path from from_state");
    std::reverse(result.begin(), result.end());
    return result;
  }

  std::optional<std::vector<Edge>> StringToDFAPath(absl::string_view s) const;

  std::optional<std::string> DFAPathToString(
      const std::vector<Edge>& path, size_t start_offset = 0,
      std::optional<size_t> end = std::nullopt) const;

  size_t state_count() const { return states_.size(); }

  int end_state_id() const { return end_state_id_; }

 private:
  RegexpDFA() {}

  // Given a state and the next input character, try to match the character and
  // return the index in State::next. Return `nullopt` if the matching fails.
  std::optional<int> NextState(const State& cur_state,
                               const std::vector<std::int16_t>& input_chars,
                               size_t& cur_index) const;

  static std::unique_ptr<re2::Prog> CompileRegexp(absl::string_view regexp);

  void BuildEntireDFA(std::unique_ptr<re2::Prog> compiled_regexp);

  // Assign weights (the probability of being picked during random walk)
  // for edges of the DFA so that very long strings are less likely. All the
  // edge weights of a state sum to 1. A node is "safe" if it has a high
  // chance (currently at least 50%, defined by `kProbToSafeNode`) to reach
  // closer to the end state. We separate the edges into safe and unsafe ones.
  // The safe edges transition to safe nodes. First we mark all end
  // states/nodes as safe nodes. Next we start BFS from the safe nodes. For each
  // node to be handled during the exploration, we assign
  // `kProbToSafeNode/num_safe_edges` to each safe edge and `(1 -
  // kProbToSafeNode)/num_of_unsafe_edges` to the unsafe ones. After that we can
  // mark the node as safe because it has now at least 50% chance to go to
  // another safe node, which is closer to the end nodes.
  void ComputeEdgeWeights();

  // Compress the DFA so that every state except the end states has at least
  // two outgoing states. With this condition, all non-end states are good
  // candidates for mutation.
  void CompressStates();

  // We need a special character representing "end of string". This is necessary
  // to make sure that we have exact matches: i.e., that we always reach end
  // states with the "end of string".
  static constexpr std::int16_t kEndOfString = 256;

  std::vector<State> states_;
  // Defaults to -1 so the member never holds an indeterminate value.
  int end_state_id_ = -1;
};
} // namespace fuzztest::internal
#endif // FUZZTEST_FUZZTEST_INTERNAL_DOMAINS_REGEXP_DFA_H_