| // Copyright 2022 Google LLC |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| #include "./grammar_codegen/backend.h" |
| |
| #include <cctype> |
| #include <random> |
| #include <string> |
| #include <string_view> |
| #include <vector> |
| |
| #include "absl/strings/ascii.h" |
| #include "absl/strings/str_format.h" |
| #include "absl/strings/str_join.h" |
| #include "./grammar_codegen/grammar_info.h" |
| #include "./fuzztest/internal/logging.h" |
| |
| namespace fuzztest::internal::grammar { |
| |
| namespace { |
| |
| std::vector<GrammarRule> SimplifyProductionWithFallbackIndex( |
| ProductionWithFallbackIndex& productions, std::string_view symbol_name); |
| |
| void SwitchBlockToNewNonTerminal(Block& block, std::string symbol_name) { |
| block.element.emplace<NonTerminal>(NonTerminal{symbol_name}); |
| } |
| |
| std::string CreateIRNodeNameForClass(std::string_view symbol_name) { |
| static int counter = 0; |
| FUZZTEST_INTERNAL_CHECK( |
| symbol_name.size() > 4 && |
| symbol_name.substr(symbol_name.size() - 4) == "Node", |
| std::string("Not a valid symbol class name: ") + " " + |
| std::string(symbol_name)); |
| return absl::StrFormat( |
| "%sSubNode%d", symbol_name.substr(0, symbol_name.size() - 4), counter++); |
| } |
| |
| // Create a IR node for a vector-like block. |
| // For example, A: B | C*, we convert it into A: B | N, N: C*. So that A can be |
| // reprensented with a Variant and N can be represented with a Vector. |
| GrammarRule SimplifyVectorLikeBlock(Block& block, |
| std::string_view parent_symbol_name) { |
| ProductionRule prod_rule{{block}}; |
| block.range = Range::kNoRange; |
| std::string new_symbol_name = CreateIRNodeNameForClass(parent_symbol_name); |
| SwitchBlockToNewNonTerminal(block, new_symbol_name); |
| return GrammarRule{new_symbol_name, |
| ProductionWithFallbackIndex{0, {std::move(prod_rule)}}}; |
| } |
| |
| // Create a IR node for a tuple-like production. |
| // For example, A: B | C D, we convert it into A: B | N, N: C D. So that A can |
| // be reprensented with a Variant and N can be represented with a Tuple. |
| GrammarRule SimplifyTupleLikeProduction(ProductionRule& production, |
| std::string_view parent_symbol_name) { |
| std::string new_symbol_name = CreateIRNodeNameForClass(parent_symbol_name); |
| Block block{Range::kNoRange, NonTerminal{new_symbol_name}}; |
| GrammarRule result = GrammarRule{ |
| new_symbol_name, ProductionWithFallbackIndex{0, {production}}}; |
| production = {{block}}; |
| return result; |
| } |
| |
| // Create a IR node for a variant-like block. |
| // For example, A: B | (C | D), we convert it into A: B | N, N: C | D. So that |
| // A can be reprensented with a Variant and N can be represented with a Variant. |
| // |
| // Note: Although such a definition seems meaningless, but it is allowed by |
| // Antlr4. |
| std::vector<GrammarRule> SimplifyVariantLikeBlock( |
| Block& block, std::string_view parent_symbol_name) { |
| ProductionWithFallbackIndex inner = |
| std::get<ProductionWithFallbackIndex>(block.element); |
| std::string new_symbol_name = CreateIRNodeNameForClass(parent_symbol_name); |
| SwitchBlockToNewNonTerminal(block, new_symbol_name); |
| auto new_grammar_rules = |
| SimplifyProductionWithFallbackIndex(inner, parent_symbol_name); |
| new_grammar_rules.emplace_back(GrammarRule{new_symbol_name, inner}); |
| return new_grammar_rules; |
| } |
| |
| std::vector<GrammarRule> SimplifyProductionWithFallbackIndex( |
| ProductionWithFallbackIndex& productions, std::string_view symbol_name) { |
| std::vector<GrammarRule> intermediate_grammar_rules; |
| for (ProductionRule& production : productions.production_rules) { |
| for (Block& block : production.blocks) { |
| if (block.range != Range::kNoRange) { |
| intermediate_grammar_rules.emplace_back( |
| SimplifyVectorLikeBlock(block, symbol_name)); |
| } |
| if (block.element.index() == BlockType::kSubProductions) { |
| auto new_grammar_rules = SimplifyVariantLikeBlock(block, symbol_name); |
| intermediate_grammar_rules.insert(intermediate_grammar_rules.end(), |
| new_grammar_rules.begin(), |
| new_grammar_rules.end()); |
| } |
| } |
| if (productions.production_rules.size() > 1 && |
| production.blocks.size() > 1) { |
| intermediate_grammar_rules.emplace_back( |
| SimplifyTupleLikeProduction(production, symbol_name)); |
| } |
| } |
| return intermediate_grammar_rules; |
| } |
| |
| std::string WrapChildTypeWithRangedVector(std::string_view parent_type, |
| std::string_view child_type, |
| const Range range) { |
| switch (range) { |
| case Range::kNoRange: |
| return std::string(child_type); |
| case Range::kUnlimited: |
| return absl::StrFormat("Vector<k%s, %s>", parent_type, child_type); |
| case Range::kNonEmpty: |
| return absl::StrFormat("NonEmptyVector<k%s, %s>", parent_type, |
| child_type); |
| case Range::kOptional: |
| return absl::StrFormat("Optional<k%s, %s>", parent_type, child_type); |
| } |
| } |
| } // namespace |
| |
| void CodeGenerator::Preprocess(Grammar& grammar) { |
| std::vector<GrammarRule> new_grammar_rules; |
| for (GrammarRule& rule : grammar.rules) { |
| auto new_ir_rules = SimplifyProductionWithFallbackIndex( |
| rule.productions, GetClassNameForSymbol(rule.symbol_name)); |
| new_grammar_rules.insert(new_grammar_rules.end(), new_ir_rules.begin(), |
| new_ir_rules.end()); |
| } |
| grammar.rules.insert(grammar.rules.end(), new_grammar_rules.begin(), |
| new_grammar_rules.end()); |
| } |
| |
| std::string CodeGenerator::Generate() { |
| Preprocess(grammar_); |
| constexpr std::string_view kCodeTemplate = |
| "#ifndef FUZZTEST_GRAMMARS_%1$s_GRAMMAR_H_\n" |
| "#define " |
| "FUZZTEST_GRAMMARS_%1$s_GRAMMAR_H_\n\n" |
| "#include " |
| "\"./fuzztest/internal/domains/in_grammar_impl.h\"\n\n" |
| "namespace fuzztest::internal::grammar::%2$s {\n\n" |
| "%3$s" |
| "} // namespace fuzztest::internal::grammar::%2$s\n" |
| "namespace fuzztest::internal_no_adl{\n\n" |
| "inline auto In%4$sGrammar() {" |
| "return " |
| "internal::grammar::InGrammarImpl<internal::grammar::%2$s::%4$sNode>();" |
| "}\n\n" |
| "} // namespace fuzztest::internal_no_adl\n" |
| "#endif // " |
| "FUZZTEST_GRAMMARS_%1$s_GRAMMAR_H_"; |
| |
| CalculateFallBackIndex(grammar_.rules); |
| std::string generated_code; |
| |
| std::string class_definitions; |
| for (GrammarRule& rule : grammar_.rules) { |
| absl::StrAppend(&class_definitions, BuildClassDefinitionForSymbol(rule)); |
| } |
| |
| for (auto& [literal, class_name] : literal_node_ids_) { |
| absl::StrAppend(&class_definitions, |
| BuildClassDefinitionForLiteral(class_name)); |
| } |
| |
| for (auto& [charset, class_name] : charset_node_ids_) { |
| absl::StrAppend(&class_definitions, |
| BuilldClassDefinitionForCharSet(class_name)); |
| } |
| |
| // The literals and charsets are collected when we build the definitions. So |
| // the forward declaration has to be built after the definitions are built. |
| std::string enum_for_ast_types; |
| std::string forward_declaration; |
| std::string string_literal_definitions; |
| for (GrammarRule& rule : grammar_.rules) { |
| absl::StrAppend(&forward_declaration, "class ", |
| GetClassNameForSymbol(rule.symbol_name), ";"); |
| absl::StrAppendFormat(&enum_for_ast_types, "k%s,", |
| GetClassNameForSymbol(rule.symbol_name)); |
| } |
| for (auto& [content, class_name] : literal_node_ids_) { |
| absl::StrAppend(&forward_declaration, "class ", class_name, ";"); |
| absl::StrAppendFormat(&string_literal_definitions, |
| "inline constexpr absl::string_view kStr%s = %s;", |
| class_name, content); |
| absl::StrAppendFormat(&enum_for_ast_types, "k%s,", class_name); |
| } |
| for (auto& [content, class_name] : charset_node_ids_) { |
| absl::StrAppend(&forward_declaration, "class ", class_name, ";"); |
| absl::StrAppendFormat(&string_literal_definitions, |
| "inline constexpr absl::string_view kStr%s = \"%s\";", |
| class_name, content); |
| absl::StrAppendFormat(&enum_for_ast_types, "k%s,", class_name); |
| } |
| |
| std::string captilialized_grammar_name = grammar_.grammar_name; |
| captilialized_grammar_name[0] = toupper(captilialized_grammar_name[0]); |
| |
| enum_for_ast_types = absl::StrFormat( |
| "enum %sTypes {%s};", captilialized_grammar_name, enum_for_ast_types); |
| |
| absl::StrAppend(&generated_code, enum_for_ast_types, forward_declaration, |
| "\n\n", string_literal_definitions, "\n\n", |
| class_definitions); |
| |
| std::string upper_grammar_name = absl::AsciiStrToUpper(grammar_.grammar_name); |
| return absl::StrFormat(kCodeTemplate, upper_grammar_name, |
| grammar_.grammar_name, generated_code, |
| captilialized_grammar_name); |
| } |
| |
| std::string CodeGenerator::BuildBaseTypeForGrammarRule( |
| const GrammarRule& rule) { |
| std::string class_name = GetClassNameForSymbol(rule.symbol_name); |
| const std::vector<ProductionRule>& prod_rules = |
| rule.productions.production_rules; |
| FUZZTEST_INTERNAL_CHECK(!prod_rules.empty(), "No expansion!"); |
| if (prod_rules.size() > 1) { |
| // This is a variant. |
| std::vector<std::string> production_child_types; |
| for (const ProductionRule& prod_rule : prod_rules) { |
| FUZZTEST_INTERNAL_CHECK(prod_rule.blocks.size() == 1, |
| "Incorrect preprocess."); |
| auto block = prod_rule.blocks[0]; |
| FUZZTEST_INTERNAL_CHECK( |
| block.range == Range::kNoRange && |
| block.element.index() != BlockType::kSubProductions, |
| "Incorrect preprocess."); |
| production_child_types.push_back(GetClassName(block)); |
| } |
| return absl::StrFormat("VariantDomain<k%s, %d, %s>", class_name, |
| *rule.productions.fallback_index, |
| absl::StrJoin(production_child_types, ",")); |
| } else if (prod_rules[0].blocks.size() == 1 && |
| prod_rules[0].blocks[0].range != Range::kNoRange) { |
| // This is a vector. |
| auto block = prod_rules[0].blocks[0]; |
| return WrapChildTypeWithRangedVector(class_name, GetClassName(block), |
| block.range); |
| } else { |
| // This is a tuple. |
| std::vector<std::string> production_child_types; |
| auto blocks = prod_rules[0].blocks; |
| for (const auto& block : blocks) { |
| FUZZTEST_INTERNAL_CHECK( |
| block.range == Range::kNoRange && |
| block.element.index() != BlockType::kSubProductions, |
| "Incorrect preprocess."); |
| production_child_types.push_back(GetClassName(block)); |
| } |
| return absl::StrFormat("TupleDomain<k%s, %s>", class_name, |
| absl::StrJoin(production_child_types, ",")); |
| } |
| } |
| |
| std::string CodeGenerator::BuildClassDefinitionForSymbol(GrammarRule& rule) { |
| return absl::StrFormat("class %s final : public %s {};\n", |
| GetClassNameForSymbol(rule.symbol_name), |
| BuildBaseTypeForGrammarRule(rule)); |
| } |
| |
| std::string CodeGenerator::BuilldClassDefinitionForCharSet( |
| std::string_view class_name) { |
| return absl::StrFormat( |
| "class %s final: public RegexLiteralDomain<k%s, kStr%s> {};", class_name, |
| class_name, class_name); |
| } |
| |
| std::string CodeGenerator::BuildClassDefinitionForLiteral( |
| std::string_view class_name) { |
| return absl::StrFormat( |
| "class %s final: public StringLiteralDomain<k%s, kStr%s>{};", class_name, |
| class_name, class_name); |
| } |
| |
// Helpers used to calculate the fallback indexes for all the symbols
// (including sub-productions).
| bool CodeGenerator::IsSymbolSafe(std::string_view symbol) { |
| return safe_rules_.find(symbol) != safe_rules_.end(); |
| } |
| |
| void CodeGenerator::MarkSymbolAsSafe(std::string_view symbol) { |
| safe_rules_.insert(std::string(symbol)); |
| } |
| |
| bool CodeGenerator::TryMarkProductionRuleVecAsSafe( |
| ProductionWithFallbackIndex& productions) { |
| std::vector<size_t> index_of_safe_productions; |
| for (size_t i = 0; i < productions.production_rules.size(); ++i) { |
| if (TryMarkProductionRuleAsSafe(productions.production_rules[i])) |
| index_of_safe_productions.push_back(i); |
| } |
| if (index_of_safe_productions.empty()) return false; |
| if (!productions.fallback_index.has_value()) |
| productions.fallback_index = index_of_safe_productions[0]; |
| return true; |
| } |
| |
| // A range is safe if it allows the symbol to generate nothing. |
| bool CodeGenerator::HasSafeRange(const Block& block) { |
| return block.range == Range::kOptional || block.range == Range::kUnlimited; |
| } |
| |
| bool CodeGenerator::TryMarkBlockAsSafe(Block& block) { |
| std::variant<Terminal, NonTerminal, ProductionWithFallbackIndex>& element = |
| block.element; |
| switch (element.index()) { |
| case BlockType::kTerminal: |
| return true; |
| case BlockType::kNonTerminal: |
| return IsSymbolSafe(std::get<BlockType::kNonTerminal>(element).name) || |
| HasSafeRange(block); |
| case BlockType::kSubProductions: { |
| bool is_safe = TryMarkProductionRuleVecAsSafe( |
| std::get<BlockType::kSubProductions>(element)); |
| if (!is_safe && HasSafeRange(block)) { |
| std::get<BlockType::kSubProductions>(element).fallback_index = 0; |
| is_safe = true; |
| } |
| return is_safe; |
| } |
| default: |
| FUZZTEST_INTERNAL_CHECK(false, "The execution should never reach here!"); |
| } |
| } |
| |
| bool CodeGenerator::TryMarkProductionRuleAsSafe(ProductionRule& prod_rule) { |
| for (Block& block : prod_rule.blocks) { |
| if (!TryMarkBlockAsSafe(block)) return false; |
| } |
| return true; |
| } |
| |
| bool CodeGenerator::TryMarkGrammarRuleAsSafe(GrammarRule& rule) { |
| return TryMarkProductionRuleVecAsSafe(rule.productions); |
| } |
| |
| void CodeGenerator::CalculateFallBackIndex(std::vector<GrammarRule>& rules) { |
| std::vector<bool> safe_rule_indexes(rules.size(), false); |
| bool has_change = true; |
| do { |
| has_change = false; |
| for (size_t i = 0; i < rules.size(); ++i) { |
| if (safe_rule_indexes[i]) continue; |
| GrammarRule& rule = rules[i]; |
| if (TryMarkGrammarRuleAsSafe(rule)) { |
| has_change = true; |
| safe_rule_indexes[i] = true; |
| MarkSymbolAsSafe(rule.symbol_name); |
| } |
| } |
| } while (has_change); |
| |
| for (size_t i = 0; i < safe_rule_indexes.size(); ++i) { |
| FUZZTEST_INTERNAL_CHECK(safe_rule_indexes[i], "Some node is not safe!"); |
| } |
| |
| // Ensure that every sub-block is marked safe. For example, a grammar rule |
| // is `expr: Literal | (expr '+' expr)`. This rule will be marked as safe |
| // with fallback index as 0. However, the sub-block `(expr '+' expr)` is |
| // not marked as safe yet. During code generation, it requires every |
| // sub-block that is a variant must have a fallback index. Therefore, we |
| // do an extra run of marking. |
| for (GrammarRule& rule : rules) { |
| TryMarkGrammarRuleAsSafe(rule); |
| } |
| } |
| |
| // Helper functions. |
| |
| // Get the name of the generated class for the block. |
| std::string CodeGenerator::GetClassName(const Block& block) { |
| switch (block.element.index()) { |
| case BlockType::kTerminal: { |
| const Terminal& terminal = std::get<BlockType::kTerminal>(block.element); |
| return terminal.type == TerminalType::kStringLiteral |
| ? GetClassNameForLiteral(terminal.content) |
| : GetClassNameForCharSet(terminal.content); |
| } |
| case BlockType::kNonTerminal: |
| return GetClassNameForSymbol( |
| std::get<BlockType::kNonTerminal>(block.element).name); |
| default: |
| FUZZTEST_INTERNAL_CHECK(false, "A sub-block doesn't have a name!"); |
| } |
| return ""; |
| } |
| |
| std::string CodeGenerator::GetClassNameForSymbol(std::string id) { |
| FUZZTEST_INTERNAL_CHECK(!id.empty(), "Empty node name!"); |
| id[0] = toupper(id[0]); |
| if (std::isdigit(id.back())) { |
| return id; |
| } else { |
| return absl::StrFormat("%sNode", id); |
| } |
| } |
| |
| std::string CodeGenerator::GetClassNameForLiteral(std::string_view s) { |
| if (literal_node_ids_.find(s) == literal_node_ids_.end()) { |
| literal_node_ids_[s] = |
| absl::StrFormat("Literal%d", literal_node_ids_.size()); |
| } |
| return literal_node_ids_[s]; |
| } |
| |
| std::string CodeGenerator::GetClassNameForCharSet(std::string_view s) { |
| if (charset_node_ids_.find(s) == charset_node_ids_.end()) { |
| charset_node_ids_[s] = |
| absl::StrFormat("CharSet%d", charset_node_ids_.size()); |
| } |
| return charset_node_ids_[s]; |
| } |
| }; // namespace fuzztest::internal::grammar |