No public description

PiperOrigin-RevId: 761483512
diff --git a/domain_tests/aggregate_combinators_test.cc b/domain_tests/aggregate_combinators_test.cc
index 4e4c175..31bb883 100644
--- a/domain_tests/aggregate_combinators_test.cc
+++ b/domain_tests/aggregate_combinators_test.cc
@@ -207,27 +207,27 @@
 
 TEST(VariantOf, InitGenerateValidValues) {
   using X = std::variant<int, int, std::string>;
-  Domain<X> domain =
-      VariantOf(ElementOf({5, 10}), InRange(50, 500), PrintableAsciiString());
+  Domain<X> domain = VariantOf(
+      ElementOf({5, 10}), InRange(50, 500),
+      PrintableAsciiString().WithSeeds({"goober", "doober", "skoober"}));
   absl::BitGen bitgen;
   std::vector<absl::flat_hash_set<int>> int_unique_value_arr(2);
   absl::flat_hash_set<std::string> string_unique_values;
-  constexpr int n = 20;
   while (int_unique_value_arr[0].size() < 2 ||
-         int_unique_value_arr[1].size() < n ||
-         string_unique_values.size() < n) {
-    X value = Value(domain, bitgen).user_value;
-    if (value.index() == 0) {
-      int val = std::get<0>(value);
+         int_unique_value_arr[1].size() < 20 ||
+         string_unique_values.size() < 4) {  // 4 > 3 seeds: forces a non-seed
+    auto value = Value(domain, bitgen);
+    if (value.user_value.index() == 0) {
+      int val = std::get<0>(value.user_value);
       EXPECT_THAT(val, AnyOf(5, 10));
       int_unique_value_arr[0].insert(val);
-    } else if (value.index() == 1) {
-      int val = std::get<1>(value);
+    } else if (value.user_value.index() == 1) {
+      int val = std::get<1>(value.user_value);
       EXPECT_LE(val, 500);
       EXPECT_GE(val, 50);
       int_unique_value_arr[1].insert(val);
     } else {
-      string_unique_values.insert(std::get<2>(value));
+      string_unique_values.insert(std::get<2>(value.user_value));
     }
   }
 }
diff --git a/domain_tests/arbitrary_domains_protobuf_test.cc b/domain_tests/arbitrary_domains_protobuf_test.cc
index b2fe236..b8a67c1 100644
--- a/domain_tests/arbitrary_domains_protobuf_test.cc
+++ b/domain_tests/arbitrary_domains_protobuf_test.cc
@@ -571,8 +571,7 @@
               },
               1)
           .WithFieldsAlwaysSet();
-
-  EXPECT_THAT(GenerateInitialValues(domain, 1000),
+  EXPECT_THAT(GenerateNonUniqueValues(domain, 100),
               Contains(ResultOf(
                   [](const Value<Domain<TestProtobuf>>& val) {
                     return val.user_value.rep_i32_size();
diff --git a/domain_tests/container_combinators_test.cc b/domain_tests/container_combinators_test.cc
index 9416cae..1d83375 100644
--- a/domain_tests/container_combinators_test.cc
+++ b/domain_tests/container_combinators_test.cc
@@ -324,6 +324,20 @@
   VerifyRoundTripThroughConversion(GenerateValues(domain), domain);
 }
 
+// Reproducer for b/394904065: fixed-size unique vector over [abc]* strings.
+TEST(UniqueElementsVectorOf, GeneratesValuesFromInnerDomainWithMaxSize) {
+  const int kNumElementsThatWillWork = 3 /* num chars */ + 1 /* empty string */
+                                       + 1;
+  auto string_domain =
+      StringOf(fuzztest::InRange<char>('a', 'c')).WithMaxSize(400);
+  auto domain =
+      UniqueElementsVectorOf(string_domain).WithSize(kNumElementsThatWillWork);
+  for (const auto& value : GenerateValues(domain)) {
+    ASSERT_THAT(value.user_value, Each(MatchesRegex("[abc]*")));
+    ASSERT_THAT(value.user_value, SizeIs(kNumElementsThatWillWork));
+  }
+}
+
 TEST(ContainerCombinatorTest, ArrayOfOne) {
   // A domain of std::array<T, 1> values can be defined in two ways:
   auto with_explicit_size = ArrayOf<1>(InRange(0.0, 1.0));
diff --git a/domain_tests/container_test.cc b/domain_tests/container_test.cc
index d27bdd9..2b00ae6 100644
--- a/domain_tests/container_test.cc
+++ b/domain_tests/container_test.cc
@@ -108,12 +108,15 @@
 
     ASSERT_THAT(v.user_value, size_match);
     sizes.insert(v.user_value.size());
-    v.Mutate(domain, bitgen, {}, false);
-    ASSERT_THAT(v.user_value, size_match);
+    for (int j = 0; j < 10; ++j) {
+      v.Mutate(domain, bitgen, {}, false);
+      ASSERT_THAT(v.user_value, size_match);
+      sizes.insert(v.user_value.size());
+    }
 
     // Mutating the value can reach the max, but only for small max sizes
     // because it would otherwise take too long.
-    if (max_size <= 10) {
+    if (max_size <= v.user_value.size() + 10) {
       auto max_v = v;
       while (max_v.user_value.size() < max_size) {
         v.Mutate(domain, bitgen, {}, false);
diff --git a/domain_tests/domain_testing.h b/domain_tests/domain_testing.h
index e8b9009..6bddb95 100644
--- a/domain_tests/domain_testing.h
+++ b/domain_tests/domain_testing.h
@@ -267,10 +267,12 @@
   absl::BitGen bitgen;
 
   absl::flat_hash_set<Value<Domain>> seeds;
-  // Make sure we can make some unique seeds.
-  // Randomness might create duplicates so keep going until we got them.
   while (seeds.size() < num_seeds) {
-    seeds.insert(Value(domain, bitgen));
+    auto value = Value(domain, bitgen);
+    while (!seeds.insert(value).second) {
+      value.Mutate(domain, bitgen, domain_implementor::MutationMetadata(),
+                   /*only_shrink=*/false);
+    }
   }
 
   auto values = seeds;
@@ -326,7 +328,8 @@
   absl::BitGen bitgen;
   values.reserve(n);
   for (int i = 0; i < n; ++i) {
-    values.push_back(Value(domain, bitgen));
+    auto value = Value(domain, bitgen);
+    values.push_back(std::move(value));
   }
   return values;
 }
diff --git a/domain_tests/map_filter_combinator_test.cc b/domain_tests/map_filter_combinator_test.cc
index ebf66ed..ee5c84a 100644
--- a/domain_tests/map_filter_combinator_test.cc
+++ b/domain_tests/map_filter_combinator_test.cc
@@ -275,6 +275,8 @@
   // Generate something shrinkable
   while (!value.has_value() || value->user_value.empty()) {
     value = Value(domain, bitgen);
+    value->Mutate(domain, bitgen, domain_implementor::MutationMetadata(),
+                  /*only_shrink=*/false);
   }
   auto mutated = value->corpus_value;
   while (!domain.GetValue(mutated).empty()) {
@@ -336,6 +338,17 @@
   EXPECT_THAT(seen, UnorderedElementsAre(2, 4, 6, 8, 10));
 }
 
+TEST(Filter, CanFilterInitCallsOnUtf8String) {
+  Domain<std::string> domain =
+      Filter([](const std::string& s) { return !s.empty(); }, Utf8String());
+  absl::BitGen bitgen;
+  Set<std::string> seen;
+  // Sample until two distinct values are seen; each one must have passed the
+  // filter, so check the elements and not merely that the set is non-empty.
+  while (seen.size() < 2) seen.insert(Value(domain, bitgen).user_value);
+  EXPECT_THAT(seen, ::testing::Each(Not(IsEmpty())));
+}
+
 TEST(Filter, CanFilterMutateCalls) {
   Domain<int> domain = Filter([](int i) { return i % 2 == 0; }, InRange(1, 10));
   absl::BitGen bitgen;
diff --git a/domain_tests/string_domains_test.cc b/domain_tests/string_domains_test.cc
index 0ad0906..43db9eb 100644
--- a/domain_tests/string_domains_test.cc
+++ b/domain_tests/string_domains_test.cc
@@ -67,7 +67,12 @@
   Set<TypeParam> unique;
 
   for (int i = 0; i < 100; ++i) {
-    values.emplace_back(domain, bitgen);
+    auto value = Value(domain, bitgen);
+    for (int j = 0; j < 10; ++j) {
+      value.Mutate(domain, bitgen, domain_implementor::MutationMetadata(),
+                   /*only_shrink=*/false);
+    }
+    values.emplace_back(value, domain);
     unique.insert(values.back().user_value);
   }
 
@@ -163,8 +168,10 @@
 
 TEST(Domain, Utf8StringWorksWithSeeds) {
   auto domain = Utf8String().WithSeeds({"\u0414\u0430!\n"});
-  EXPECT_THAT(GenerateValues(domain),
-              Contains(Value(domain, "\u0414\u0430!\n")));
+  EXPECT_THAT(
+      GenerateInitialValues(
+          domain, IterationsToHitAll(/*num_cases=*/1, /*hit_probability=*/0.5)),
+      Contains(Value(domain, "\u0414\u0430!\n")));
 }
 
 TEST(Domain, Utf8StringIgnoresInvalideSeeds) {
diff --git a/fuzztest/internal/domains/container_of_impl.h b/fuzztest/internal/domains/container_of_impl.h
index 89698a1..0c18923 100644
--- a/fuzztest/internal/domains/container_of_impl.h
+++ b/fuzztest/internal/domains/container_of_impl.h
@@ -353,21 +353,6 @@
  protected:
   InnerDomainT inner_;
 
-  size_t ChooseRandomInitialSize(absl::BitGenRef prng) {
-    // The container size should not be empty (unless max_size_ = 0) because the
-    // initialization should be random if possible.
-    // TODO(changochen): Increase the number of generated elements.
-    // Currently we make container generate zero or one element to avoid
-    // infinite recursion in recursive data structures. For the example, we want
-    // to build a domain for `struct X{ int leaf; vector<X> recursive`, the
-    // expected generated length is `E(X) = E(leaf) + E(recursive)`. If the
-    // container generate `0-10` elements when calling `Init`, then
-    // `E(recursive) =  4.5 E(X)`, which will make `E(X) = Infinite`.
-    // Make some smallish random seed containers.
-    return absl::Uniform(prng, min_size(),
-                         std::min(max_size() + 1, min_size() + 2));
-  }
-
   size_t min_size() const { return min_size_; }
   size_t max_size() const {
     return max_size_.value_or(std::max(min_size_, kDefaultContainerMaxSize));
@@ -431,10 +416,9 @@
 
   corpus_type Init(absl::BitGenRef prng) {
     if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
-    const size_t size = this->ChooseRandomInitialSize(prng);
 
     corpus_type val;
-    Grow(val, prng, size, 10000);
+    Grow(val, prng, this->min_size(), 10000);
     if (val.size() < this->min_size()) {
       // We tried to make a container with the minimum specified size and we
       // could not after a lot of attempts. This could be caused by an
@@ -468,17 +452,6 @@
     Grow(val, prng, 1, kFailuresAllowed);
   }
 
-  auto GetRandomInnerValue(absl::BitGenRef prng) {
-    constexpr int kMaxMutations = 100;
-    const int num_mutations = absl::Zipf(prng, /*hi=*/kMaxMutations);
-    auto element = this->inner_.Init(prng);
-    for (int i = 0; i < num_mutations; ++i) {
-      this->inner_.Mutate(element, prng, domain_implementor::MutationMetadata(),
-                          /*only_shrink=*/false);
-    }
-    return element;
-  }
-
   // Try to grow `val` by `n` elements.
   void Grow(corpus_type& val, absl::BitGenRef prng, size_t n,
             size_t failures_allowed) {
@@ -492,7 +465,7 @@
     auto real_value = this->GetValue(val);
     const size_t final_size = real_value.size() + n;
     while (real_value.size() < final_size) {
-      auto new_element = GetRandomInnerValue(prng);
+      auto new_element = this->inner_.GetRandomCorpusValue(prng);
       if (real_value.insert(this->inner_.GetValue(new_element)).second) {
         val.push_back(std::move(new_element));
       } else {
@@ -550,10 +523,9 @@
 
   corpus_type Init(absl::BitGenRef prng) {
     if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
-    const size_t size = this->ChooseRandomInitialSize(prng);
     corpus_type val;
-    while (val.size() < size) {
-      val.insert(val.end(), GetRandomInnerValue(prng));
+    while (val.size() < this->min_size()) {
+      val.insert(val.end(), this->inner_.GetRandomCorpusValue(prng));
     }
     return val;
   }
@@ -587,17 +559,6 @@
                this->inner_.Init(prng));
   }
 
-  auto GetRandomInnerValue(absl::BitGenRef prng) {
-    constexpr int kMaxMutations = 100;
-    int num_mutations = absl::Zipf<int>(prng, /*hi=*/kMaxMutations);
-    auto element = this->inner_.Init(prng);
-    for (int i = 0; i < num_mutations; ++i) {
-      this->inner_.Mutate(element, prng, domain_implementor::MutationMetadata(),
-                          /*only_shrink=*/false);
-    }
-    return element;
-  }
-
   void MutateElement(corpus_type&, absl::BitGenRef prng,
                      const domain_implementor::MutationMetadata& metadata,
                      bool only_shrink, typename corpus_type::iterator it) {
diff --git a/fuzztest/internal/domains/domain.h b/fuzztest/internal/domains/domain.h
index a384e8a..1d38319 100644
--- a/fuzztest/internal/domains/domain.h
+++ b/fuzztest/internal/domains/domain.h
@@ -183,6 +183,12 @@
     return Mutate(corpus_value, prng, {}, only_shrink);
   }
 
+  // Returns a random corpus value by forwarding to the type-erased `inner_`.
+  // See the notes on GetRandomValue() above for caveats of random values.
+  corpus_type GetRandomCorpusValue(absl::BitGenRef prng) {
+    return inner_->UntypedGetRandomCorpusValue(prng);
+  }
+
   // The methods below are responsible for transforming between the above
   // described three types that domains deal with. Here's a quick overview:
   //
diff --git a/fuzztest/internal/domains/domain_base.h b/fuzztest/internal/domains/domain_base.h
index 4243dcf..28faefd 100644
--- a/fuzztest/internal/domains/domain_base.h
+++ b/fuzztest/internal/domains/domain_base.h
@@ -203,21 +203,6 @@
     return derived();
   }
 
- protected:
-  // `Derived::Init()` can use this to sample seeds for this domain.
-  std::optional<CorpusType> MaybeGetRandomSeed(absl::BitGenRef prng) {
-    if (seed_provider_ != nullptr) {
-      WithSeeds(std::invoke(seed_provider_));
-      seed_provider_ = nullptr;
-    }
-
-    static constexpr double kProbabilityToReturnSeed = 0.5;
-    if (seeds_.empty() || !absl::Bernoulli(prng, kProbabilityToReturnSeed)) {
-      return std::nullopt;
-    }
-    return seeds_[fuzztest::internal::ChooseOffset(seeds_.size(), prng)];
-  }
-
   // Default implementation of GetRandomCorpusValue() without guarantees on the
   // distribution of the returned values. If possible, the derived domain should
   // override this with a better implementation.
@@ -235,6 +220,21 @@
     return corpus_val;
   }
 
+ protected:
+  // `Derived::Init()` can use this to sample seeds for this domain.
+  std::optional<CorpusType> MaybeGetRandomSeed(absl::BitGenRef prng) {
+    if (seed_provider_ != nullptr) {
+      WithSeeds(std::invoke(seed_provider_));
+      seed_provider_ = nullptr;
+    }
+
+    static constexpr double kProbabilityToReturnSeed = 0.5;
+    if (seeds_.empty() || !absl::Bernoulli(prng, kProbabilityToReturnSeed)) {
+      return std::nullopt;
+    }
+    return seeds_[absl::Uniform(prng, 0u, seeds_.size())];
+  }
+
  private:
   Derived& derived() { return static_cast<Derived&>(*this); }
   const Derived& derived() const { return static_cast<const Derived&>(*this); }
diff --git a/fuzztest/internal/domains/domain_type_erasure.h b/fuzztest/internal/domains/domain_type_erasure.h
index 2c38c3f..2dd3b08 100644
--- a/fuzztest/internal/domains/domain_type_erasure.h
+++ b/fuzztest/internal/domains/domain_type_erasure.h
@@ -74,6 +74,8 @@
   virtual uint64_t UntypedMutateSelectedField(
       GenericDomainCorpusType&, absl::BitGenRef,
       const domain_implementor::MutationMetadata&, bool, uint64_t) = 0;
+  virtual GenericDomainCorpusType UntypedGetRandomCorpusValue(
+      absl::BitGenRef prng) = 0;
   virtual GenericDomainValueType UntypedGetValue(
       const GenericDomainCorpusType& v) const = 0;
   virtual void UntypedPrintCorpusValue(
@@ -167,6 +169,12 @@
     }
   }
 
+  GenericDomainCorpusType UntypedGetRandomCorpusValue(
+      absl::BitGenRef prng) final {
+    return GenericDomainCorpusType(std::in_place_type<CorpusType>,
+                                   domain_.GetRandomCorpusValue(prng));
+  }
+
   std::optional<GenericDomainCorpusType> UntypedParseCorpus(
       const IRObject& obj) const final {
     if (auto res = domain_.ParseCorpus(obj)) {
diff --git a/fuzztest/internal/domains/filter_impl.h b/fuzztest/internal/domains/filter_impl.h
index a9b3130..f50ed56 100644
--- a/fuzztest/internal/domains/filter_impl.h
+++ b/fuzztest/internal/domains/filter_impl.h
@@ -44,7 +44,7 @@
   corpus_type Init(absl::BitGenRef prng) {
     if (auto seed = this->MaybeGetRandomSeed(prng)) return *seed;
     while (true) {
-      auto v = inner_.Init(prng);
+      auto v = inner_.GetRandomCorpusValue(prng);
       if (RunFilter(v)) return v;
     }
   }