Do hashtablez sampling on the first insertion into an empty SOO hashtable.

When sampling triggers, we skip SOO and allocate a backing array instead. We must do this because the HashtablezInfoHandle lives in the heap allocation (along with the control bytes and slots). By default, when sampling is enabled, we sample 1 in ~1024 hashtables. Sampled tables pay a performance cost because (a) they don't benefit from SOO, so they have worse data locality (more cache/TLB misses), and (b) the backing array capacity is 3 instead of 1, so (b.1) we skip the rehash after the second insertion but (b.2) we potentially waste up to two slots' worth of memory.
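
As a rough sketch of the new flow (condensed from find_or_prepare_insert_soo in the diff below; iterator plumbing and the non-empty cases are elided):

  // First insertion into an empty SOO table: decide whether to sample.
  if (empty()) {
    HashtablezInfoHandle infoz = try_sample_soo();  // hits ~1 in 1024 tables
    if (infoz.IsSampled()) {
      // The infoz handle lives in the heap allocation, so leave SOO and
      // allocate a backing array (capacity 3 == NextCapacity(SooCapacity())).
      resize_with_soo_infoz(infoz);
    } else {
      common().set_full_soo();  // stay inline; no heap allocation
      return {soo_iterator(), true};
    }
  }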

We also add an soo_capacity field to HashtablezInfo so that we can distinguish which sampled tables would otherwise have been SOO; this lets us estimate what fraction of tables are in SOO mode.
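
For example, a consumer of the sampled data could estimate that fraction along these lines (a sketch only: EstimateSooFraction is a hypothetical helper, and the would-have-been-SOO condition mirrors the comment added to HashtablezInfo below):

  #include <atomic>
  #include <cstddef>

  #include "absl/container/internal/hashtablez_sampler.h"

  double EstimateSooFraction() {
    using absl::container_internal::GlobalHashtablezSampler;
    using absl::container_internal::HashtablezInfo;
    size_t total = 0, would_be_soo = 0;
    GlobalHashtablezSampler().Iterate([&](const HashtablezInfo& info) {
      ++total;
      // Sampled tables are never actually SOO; a table would have been SOO if
      // its size and max_reserve both fit within its SOO capacity.
      const auto soo_cap = static_cast<size_t>(info.soo_capacity);
      if (soo_cap > 0 &&
          info.size.load(std::memory_order_relaxed) <= soo_cap &&
          info.max_reserve.load(std::memory_order_relaxed) <= soo_cap) {
        ++would_be_soo;
      }
    });
    return total == 0 ? 0.0 : static_cast<double>(would_be_soo) / total;
  }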

PiperOrigin-RevId: 617252334
Change-Id: Ib48b7a4870bd74ea3ba923ed8f350a3b75dbb7d3
diff --git a/absl/container/BUILD.bazel b/absl/container/BUILD.bazel
index 0de4526..c1df8d8 100644
--- a/absl/container/BUILD.bazel
+++ b/absl/container/BUILD.bazel
@@ -701,8 +701,10 @@
         "//absl/base:config",
         "//absl/base:core_headers",
         "//absl/base:prefetch",
+        "//absl/functional:function_ref",
         "//absl/hash",
         "//absl/log",
+        "//absl/log:check",
         "//absl/memory",
         "//absl/meta:type_traits",
         "//absl/strings",
diff --git a/absl/container/CMakeLists.txt b/absl/container/CMakeLists.txt
index 4b08e6a..7fa2e91 100644
--- a/absl/container/CMakeLists.txt
+++ b/absl/container/CMakeLists.txt
@@ -753,11 +753,13 @@
     ${ABSL_TEST_COPTS}
   DEPS
     absl::base
+    absl::check
     absl::config
     absl::container_memory
     absl::core_headers
     absl::flat_hash_map
     absl::flat_hash_set
+    absl::function_ref
     absl::hash
     absl::hash_function_defaults
     absl::hash_policy_testing
diff --git a/absl/container/internal/hashtablez_sampler.cc b/absl/container/internal/hashtablez_sampler.cc
index b21beef..e17ba14 100644
--- a/absl/container/internal/hashtablez_sampler.cc
+++ b/absl/container/internal/hashtablez_sampler.cc
@@ -18,13 +18,18 @@
 #include <atomic>
 #include <cassert>
 #include <cmath>
+#include <cstddef>
+#include <cstdint>
 #include <functional>
 #include <limits>
 
 #include "absl/base/attributes.h"
 #include "absl/base/config.h"
+#include "absl/base/internal/per_thread_tls.h"
 #include "absl/base/internal/raw_logging.h"
+#include "absl/base/macros.h"
 #include "absl/base/no_destructor.h"
+#include "absl/base/optimization.h"
 #include "absl/debugging/stacktrace.h"
 #include "absl/memory/memory.h"
 #include "absl/profiling/internal/exponential_biased.h"
@@ -73,7 +78,8 @@
 HashtablezInfo::~HashtablezInfo() = default;
 
 void HashtablezInfo::PrepareForSampling(int64_t stride,
-                                        size_t inline_element_size_value) {
+                                        size_t inline_element_size_value,
+                                        uint16_t soo_capacity_value) {
   capacity.store(0, std::memory_order_relaxed);
   size.store(0, std::memory_order_relaxed);
   num_erases.store(0, std::memory_order_relaxed);
@@ -93,6 +99,7 @@
   depth = absl::GetStackTrace(stack, HashtablezInfo::kMaxStackDepth,
                               /* skip_count= */ 0);
   inline_element_size = inline_element_size_value;
+  soo_capacity = soo_capacity_value;
 }
 
 static bool ShouldForceSampling() {
@@ -116,12 +123,12 @@
 }
 
 HashtablezInfo* SampleSlow(SamplingState& next_sample,
-                           size_t inline_element_size) {
+                           size_t inline_element_size, uint16_t soo_capacity) {
   if (ABSL_PREDICT_FALSE(ShouldForceSampling())) {
     next_sample.next_sample = 1;
     const int64_t old_stride = exchange(next_sample.sample_stride, 1);
-    HashtablezInfo* result =
-        GlobalHashtablezSampler().Register(old_stride, inline_element_size);
+    HashtablezInfo* result = GlobalHashtablezSampler().Register(
+        old_stride, inline_element_size, soo_capacity);
     return result;
   }
 
@@ -151,10 +158,11 @@
   // that case.
   if (first) {
     if (ABSL_PREDICT_TRUE(--next_sample.next_sample > 0)) return nullptr;
-    return SampleSlow(next_sample, inline_element_size);
+    return SampleSlow(next_sample, inline_element_size, soo_capacity);
   }
 
-  return GlobalHashtablezSampler().Register(old_stride, inline_element_size);
+  return GlobalHashtablezSampler().Register(old_stride, inline_element_size,
+                                            soo_capacity);
 #endif
 }
 
diff --git a/absl/container/internal/hashtablez_sampler.h b/absl/container/internal/hashtablez_sampler.h
index e41ee2d..4e11968 100644
--- a/absl/container/internal/hashtablez_sampler.h
+++ b/absl/container/internal/hashtablez_sampler.h
@@ -40,15 +40,20 @@
 #define ABSL_CONTAINER_INTERNAL_HASHTABLEZ_SAMPLER_H_
 
 #include <atomic>
+#include <cstddef>
+#include <cstdint>
 #include <functional>
 #include <memory>
 #include <vector>
 
+#include "absl/base/attributes.h"
 #include "absl/base/config.h"
 #include "absl/base/internal/per_thread_tls.h"
 #include "absl/base/optimization.h"
+#include "absl/base/thread_annotations.h"
 #include "absl/profiling/internal/sample_recorder.h"
 #include "absl/synchronization/mutex.h"
+#include "absl/time/time.h"
 #include "absl/utility/utility.h"
 
 namespace absl {
@@ -67,7 +72,8 @@
 
   // Puts the object into a clean state, fills in the logically `const` members,
   // blocking for any readers that are currently sampling the object.
-  void PrepareForSampling(int64_t stride, size_t inline_element_size_value)
+  void PrepareForSampling(int64_t stride, size_t inline_element_size_value,
+                          uint16_t soo_capacity_value)
       ABSL_EXCLUSIVE_LOCKS_REQUIRED(init_mu);
 
   // These fields are mutated by the various Record* APIs and need to be
@@ -91,8 +97,13 @@
   static constexpr int kMaxStackDepth = 64;
   absl::Time create_time;
   int32_t depth;
+  // The SOO capacity for this table in elements (not bytes). Note that sampled
+  // tables are never SOO because we need to store the infoz handle on the heap.
+  // Tables that would be SOO if not sampled should have: soo_capacity > 0 &&
+  // size <= soo_capacity && max_reserve <= soo_capacity.
+  uint16_t soo_capacity;
   void* stack[kMaxStackDepth];
-  size_t inline_element_size;  // How big is the slot?
+  size_t inline_element_size;  // How big is the slot in bytes?
 };
 
 void RecordRehashSlow(HashtablezInfo* info, size_t total_probe_length);
@@ -117,7 +128,7 @@
 };
 
 HashtablezInfo* SampleSlow(SamplingState& next_sample,
-                           size_t inline_element_size);
+                           size_t inline_element_size, uint16_t soo_capacity);
 void UnsampleSlow(HashtablezInfo* info);
 
 #if defined(ABSL_INTERNAL_HASHTABLEZ_SAMPLE)
@@ -204,16 +215,16 @@
 extern ABSL_PER_THREAD_TLS_KEYWORD SamplingState global_next_sample;
 #endif  // defined(ABSL_INTERNAL_HASHTABLEZ_SAMPLE)
 
-// Returns an RAII sampling handle that manages registration and unregistation
-// with the global sampler.
+// Returns a sampling handle.
 inline HashtablezInfoHandle Sample(
-    size_t inline_element_size ABSL_ATTRIBUTE_UNUSED) {
+    ABSL_ATTRIBUTE_UNUSED size_t inline_element_size,
+    ABSL_ATTRIBUTE_UNUSED uint16_t soo_capacity) {
 #if defined(ABSL_INTERNAL_HASHTABLEZ_SAMPLE)
   if (ABSL_PREDICT_TRUE(--global_next_sample.next_sample > 0)) {
     return HashtablezInfoHandle(nullptr);
   }
   return HashtablezInfoHandle(
-      SampleSlow(global_next_sample, inline_element_size));
+      SampleSlow(global_next_sample, inline_element_size, soo_capacity));
 #else
   return HashtablezInfoHandle(nullptr);
 #endif  // !ABSL_PER_THREAD_TLS
diff --git a/absl/container/internal/hashtablez_sampler_test.cc b/absl/container/internal/hashtablez_sampler_test.cc
index 8ebb08d..a80f530 100644
--- a/absl/container/internal/hashtablez_sampler_test.cc
+++ b/absl/container/internal/hashtablez_sampler_test.cc
@@ -15,8 +15,12 @@
 #include "absl/container/internal/hashtablez_sampler.h"
 
 #include <atomic>
+#include <cassert>
+#include <cstddef>
+#include <cstdint>
 #include <limits>
 #include <random>
+#include <vector>
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
@@ -67,7 +71,7 @@
 HashtablezInfo* Register(HashtablezSampler* s, size_t size) {
   const int64_t test_stride = 123;
   const size_t test_element_size = 17;
-  auto* info = s->Register(test_stride, test_element_size);
+  auto* info = s->Register(test_stride, test_element_size, /*soo_capacity=*/0);
   assert(info != nullptr);
   info->size.store(size);
   return info;
@@ -79,7 +83,8 @@
   const size_t test_element_size = 17;
   HashtablezInfo info;
   absl::MutexLock l(&info.init_mu);
-  info.PrepareForSampling(test_stride, test_element_size);
+  info.PrepareForSampling(test_stride, test_element_size,
+                          /*soo_capacity_value=*/1);
 
   EXPECT_EQ(info.capacity.load(), 0);
   EXPECT_EQ(info.size.load(), 0);
@@ -94,6 +99,7 @@
   EXPECT_GE(info.create_time, test_start);
   EXPECT_EQ(info.weight, test_stride);
   EXPECT_EQ(info.inline_element_size, test_element_size);
+  EXPECT_EQ(info.soo_capacity, 1);
 
   info.capacity.store(1, std::memory_order_relaxed);
   info.size.store(1, std::memory_order_relaxed);
@@ -106,7 +112,8 @@
   info.max_reserve.store(1, std::memory_order_relaxed);
   info.create_time = test_start - absl::Hours(20);
 
-  info.PrepareForSampling(test_stride * 2, test_element_size);
+  info.PrepareForSampling(test_stride * 2, test_element_size,
+                          /*soo_capacity_value=*/0);
   EXPECT_EQ(info.capacity.load(), 0);
   EXPECT_EQ(info.size.load(), 0);
   EXPECT_EQ(info.num_erases.load(), 0);
@@ -120,6 +127,7 @@
   EXPECT_EQ(info.weight, 2 * test_stride);
   EXPECT_EQ(info.inline_element_size, test_element_size);
   EXPECT_GE(info.create_time, test_start);
+  EXPECT_EQ(info.soo_capacity, 0);
 }
 
 TEST(HashtablezInfoTest, RecordStorageChanged) {
@@ -127,7 +135,8 @@
   absl::MutexLock l(&info.init_mu);
   const int64_t test_stride = 21;
   const size_t test_element_size = 19;
-  info.PrepareForSampling(test_stride, test_element_size);
+  info.PrepareForSampling(test_stride, test_element_size,
+                          /*soo_capacity_value=*/0);
   RecordStorageChangedSlow(&info, 17, 47);
   EXPECT_EQ(info.size.load(), 17);
   EXPECT_EQ(info.capacity.load(), 47);
@@ -141,7 +150,8 @@
   absl::MutexLock l(&info.init_mu);
   const int64_t test_stride = 25;
   const size_t test_element_size = 23;
-  info.PrepareForSampling(test_stride, test_element_size);
+  info.PrepareForSampling(test_stride, test_element_size,
+                          /*soo_capacity_value=*/0);
   EXPECT_EQ(info.max_probe_length.load(), 0);
   RecordInsertSlow(&info, 0x0000FF00, 6 * kProbeLength);
   EXPECT_EQ(info.max_probe_length.load(), 6);
@@ -165,7 +175,8 @@
   const size_t test_element_size = 29;
   HashtablezInfo info;
   absl::MutexLock l(&info.init_mu);
-  info.PrepareForSampling(test_stride, test_element_size);
+  info.PrepareForSampling(test_stride, test_element_size,
+                          /*soo_capacity_value=*/1);
   EXPECT_EQ(info.num_erases.load(), 0);
   EXPECT_EQ(info.size.load(), 0);
   RecordInsertSlow(&info, 0x0000FF00, 6 * kProbeLength);
@@ -174,6 +185,7 @@
   EXPECT_EQ(info.size.load(), 0);
   EXPECT_EQ(info.num_erases.load(), 1);
   EXPECT_EQ(info.inline_element_size, test_element_size);
+  EXPECT_EQ(info.soo_capacity, 1);
 }
 
 TEST(HashtablezInfoTest, RecordRehash) {
@@ -181,7 +193,8 @@
   const size_t test_element_size = 31;
   HashtablezInfo info;
   absl::MutexLock l(&info.init_mu);
-  info.PrepareForSampling(test_stride, test_element_size);
+  info.PrepareForSampling(test_stride, test_element_size,
+                          /*soo_capacity_value=*/0);
   RecordInsertSlow(&info, 0x1, 0);
   RecordInsertSlow(&info, 0x2, kProbeLength);
   RecordInsertSlow(&info, 0x4, kProbeLength);
@@ -201,6 +214,7 @@
   EXPECT_EQ(info.num_erases.load(), 0);
   EXPECT_EQ(info.num_rehashes.load(), 1);
   EXPECT_EQ(info.inline_element_size, test_element_size);
+  EXPECT_EQ(info.soo_capacity, 0);
 }
 
 TEST(HashtablezInfoTest, RecordReservation) {
@@ -208,7 +222,8 @@
   absl::MutexLock l(&info.init_mu);
   const int64_t test_stride = 35;
   const size_t test_element_size = 33;
-  info.PrepareForSampling(test_stride, test_element_size);
+  info.PrepareForSampling(test_stride, test_element_size,
+                          /*soo_capacity_value=*/0);
   RecordReservationSlow(&info, 3);
   EXPECT_EQ(info.max_reserve.load(), 3);
 
@@ -229,7 +244,8 @@
 
   for (int i = 0; i < 1000; ++i) {
     SamplingState next_sample = {0, 0};
-    HashtablezInfo* sample = SampleSlow(next_sample, test_element_size);
+    HashtablezInfo* sample = SampleSlow(next_sample, test_element_size,
+                                        /*soo_capacity=*/0);
     EXPECT_GT(next_sample.next_sample, 0);
     EXPECT_EQ(next_sample.next_sample, next_sample.sample_stride);
     EXPECT_NE(sample, nullptr);
@@ -244,7 +260,8 @@
 
   for (int i = 0; i < 1000; ++i) {
     SamplingState next_sample = {0, 0};
-    HashtablezInfo* sample = SampleSlow(next_sample, test_element_size);
+    HashtablezInfo* sample = SampleSlow(next_sample, test_element_size,
+                                        /*soo_capacity=*/0);
     EXPECT_GT(next_sample.next_sample, 0);
     EXPECT_EQ(next_sample.next_sample, next_sample.sample_stride);
     EXPECT_NE(sample, nullptr);
@@ -260,7 +277,8 @@
   int64_t total = 0;
   double sample_rate = 0.0;
   for (int i = 0; i < 1000000; ++i) {
-    HashtablezInfoHandle h = Sample(test_element_size);
+    HashtablezInfoHandle h = Sample(test_element_size,
+                                    /*soo_capacity=*/0);
     ++total;
     if (h.IsSampled()) {
       ++num_sampled;
@@ -275,7 +293,8 @@
   auto& sampler = GlobalHashtablezSampler();
   const int64_t test_stride = 41;
   const size_t test_element_size = 39;
-  HashtablezInfoHandle h(sampler.Register(test_stride, test_element_size));
+  HashtablezInfoHandle h(sampler.Register(test_stride, test_element_size,
+                                          /*soo_capacity=*/0));
   auto* info = HashtablezInfoHandlePeer::GetInfo(&h);
   info->hashes_bitwise_and.store(0x12345678, std::memory_order_relaxed);
 
@@ -358,11 +377,13 @@
       std::vector<HashtablezInfo*> infoz;
       while (!stop.HasBeenNotified()) {
         if (infoz.empty()) {
-          infoz.push_back(sampler.Register(sampling_stride, elt_size));
+          infoz.push_back(sampler.Register(sampling_stride, elt_size,
+                                           /*soo_capacity=*/0));
         }
         switch (std::uniform_int_distribution<>(0, 2)(gen)) {
           case 0: {
-            infoz.push_back(sampler.Register(sampling_stride, elt_size));
+            infoz.push_back(sampler.Register(sampling_stride, elt_size,
+                                             /*soo_capacity=*/0));
             break;
           }
           case 1: {
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index 258458b..5818844 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -1775,17 +1775,48 @@
   }
 }
 
+template <typename CharAlloc>
+constexpr bool ShouldSampleHashtablezInfo() {
+  // Folks with custom allocators often make unwarranted assumptions about the
+  // behavior of their classes vis-a-vis trivial destructability and what
+  // calls they will or won't make.  Avoid sampling for people with custom
+  // allocators to get us out of this mess.  This is not a hard guarantee but
+  // a workaround while we plan the exact guarantee we want to provide.
+  return std::is_same<CharAlloc, std::allocator<char>>::value;
+}
+
+template <bool kSooEnabled>
+HashtablezInfoHandle SampleHashtablezInfo(size_t sizeof_slot,
+                                          size_t old_capacity, bool was_soo,
+                                          HashtablezInfoHandle forced_infoz,
+                                          CommonFields& c) {
+  if (forced_infoz.IsSampled()) return forced_infoz;
+  // In SOO, we sample on the first insertion so if this is an empty SOO case
+  // (e.g. when reserve is called), then we still need to sample.
+  if (kSooEnabled && was_soo && c.size() == 0) {
+    return Sample(sizeof_slot, SooCapacity());
+  }
+  // For non-SOO cases, we sample whenever the capacity is increasing from zero
+  // to non-zero.
+  if (!kSooEnabled && old_capacity == 0) {
+    return Sample(sizeof_slot, 0);
+  }
+  return c.infoz();
+}
+
 // Helper class to perform resize of the hash set.
 //
 // It contains special optimizations for small group resizes.
 // See GrowIntoSingleGroupShuffleControlBytes for details.
 class HashSetResizeHelper {
  public:
-  explicit HashSetResizeHelper(CommonFields& c, bool was_soo, bool had_soo_slot)
+  explicit HashSetResizeHelper(CommonFields& c, bool was_soo, bool had_soo_slot,
+                               HashtablezInfoHandle forced_infoz)
       : old_capacity_(c.capacity()),
         had_infoz_(c.has_infoz()),
         was_soo_(was_soo),
-        had_soo_slot_(had_soo_slot) {}
+        had_soo_slot_(had_soo_slot),
+        forced_infoz_(forced_infoz) {}
 
   // Optimized for small groups version of `find_first_non_full`.
   // Beneficial only right after calling `raw_hash_set::resize`.
@@ -1839,7 +1870,7 @@
   // Reads `capacity` and updates all other fields based on the result of
   // the allocation.
   //
-  // It also may do the folowing actions:
+  // It also may do the following actions:
   // 1. initialize control bytes
   // 2. initialize slots
   // 3. deallocate old slots.
@@ -1869,26 +1900,15 @@
   //
   //  Returns IsGrowingIntoSingleGroupApplicable result to avoid recomputation.
   template <typename Alloc, size_t SizeOfSlot, bool TransferUsesMemcpy,
-            size_t AlignOfSlot>
+            bool SooEnabled, size_t AlignOfSlot>
   ABSL_ATTRIBUTE_NOINLINE bool InitializeSlots(CommonFields& c, Alloc alloc,
                                                ctrl_t soo_slot_h2) {
     assert(c.capacity());
-    // Folks with custom allocators often make unwarranted assumptions about the
-    // behavior of their classes vis-a-vis trivial destructability and what
-    // calls they will or won't make.  Avoid sampling for people with custom
-    // allocators to get us out of this mess.  This is not a hard guarantee but
-    // a workaround while we plan the exact guarantee we want to provide.
-    const size_t sample_size =
-        (std::is_same<Alloc, std::allocator<char>>::value &&
-         (was_soo_ || old_capacity_ == 0))
-            ? SizeOfSlot
-            : 0;
-    // TODO(b/289225379): when SOO is enabled, we should still sample on first
-    // insertion and if Sample is non-null, then we should force a heap
-    // allocation. Note that we'll also have to force capacity of 3 so that
-    // is_soo() still works.
     HashtablezInfoHandle infoz =
-        sample_size > 0 ? Sample(sample_size) : c.infoz();
+        ShouldSampleHashtablezInfo<Alloc>()
+            ? SampleHashtablezInfo<SooEnabled>(SizeOfSlot, old_capacity_,
+                                               was_soo_, forced_infoz_, c)
+            : HashtablezInfoHandle{};
 
     const bool has_infoz = infoz.IsSampled();
     RawHashSetLayout layout(c.capacity(), AlignOfSlot, has_infoz);
@@ -1904,12 +1924,13 @@
 
     const bool grow_single_group =
         IsGrowingIntoSingleGroupApplicable(old_capacity_, layout.capacity());
-    if (was_soo_ && grow_single_group) {
+    if (SooEnabled && was_soo_ && grow_single_group) {
       InitControlBytesAfterSoo(c.control(), soo_slot_h2, layout.capacity());
       if (TransferUsesMemcpy && had_soo_slot_) {
         TransferSlotAfterSoo(c, SizeOfSlot);
       }
-    } else if (old_capacity_ != 0 && grow_single_group) {
+      // SooEnabled implies that old_capacity_ != 0.
+    } else if ((SooEnabled || old_capacity_ != 0) && grow_single_group) {
       if (TransferUsesMemcpy) {
         GrowSizeIntoSingleGroupTransferable(c, SizeOfSlot);
         DeallocateOld<AlignOfSlot>(alloc, SizeOfSlot);
@@ -1923,7 +1944,7 @@
     c.set_has_infoz(has_infoz);
     if (has_infoz) {
       infoz.RecordStorageChanged(c.size(), layout.capacity());
-      if (was_soo_ || grow_single_group || old_capacity_ == 0) {
+      if ((SooEnabled && was_soo_) || grow_single_group || old_capacity_ == 0) {
         infoz.RecordRehash(0);
       }
       c.set_infoz(infoz);
@@ -2063,6 +2084,8 @@
   bool had_infoz_;
   bool was_soo_;
   bool had_soo_slot_;
+  // Either null infoz or a pre-sampled forced infoz for SOO tables.
+  HashtablezInfoHandle forced_infoz_;
 };
 
 inline void PrepareInsertCommon(CommonFields& common) {
@@ -2545,6 +2568,8 @@
       assert(size == 1);
       common().set_full_soo();
       emplace_at(soo_iterator(), *that.begin());
+      const HashtablezInfoHandle infoz = try_sample_soo();
+      if (infoz.IsSampled()) resize_with_soo_infoz(infoz);
       return;
     }
     assert(!that.is_soo());
@@ -2682,8 +2707,7 @@
   size_t size() const { return common().size(); }
   size_t capacity() const {
     const size_t cap = common().capacity();
-    // Use local variables because compiler complains about using functions in
-    // assume.
+    // Compiler complains when using functions in assume so use local variables.
     ABSL_ATTRIBUTE_UNUSED static constexpr bool kEnabled = SooEnabled();
     ABSL_ATTRIBUTE_UNUSED static constexpr size_t kCapacity = SooCapacity();
     ABSL_ASSUME(!kEnabled || cap >= kCapacity);
@@ -3051,7 +3075,17 @@
                           SooEnabled());
         return;
       }
-      if (SooEnabled() && size() <= SooCapacity()) {
+      if (fits_in_soo(size())) {
+        // When the table is already sampled, we keep it sampled.
+        if (infoz().IsSampled()) {
+          const size_t kInitialSampledCapacity = NextCapacity(SooCapacity());
+          if (capacity() > kInitialSampledCapacity) {
+            resize(kInitialSampledCapacity);
+          }
+          // This asserts that we didn't lose sampling coverage in `resize`.
+          assert(infoz().IsSampled());
+          return;
+        }
         alignas(slot_type) unsigned char slot_space[sizeof(slot_type)];
         slot_type* tmp_slot = to_slot(slot_space);
         transfer(tmp_slot, begin().slot());
@@ -3322,6 +3356,15 @@
     }
   }
 
+  // Conditionally samples hashtablez for SOO tables. This should be called on
+  // insertion into an empty SOO table and in copy construction when the size
+  // can fit in SOO capacity.
+  inline HashtablezInfoHandle try_sample_soo() {
+    assert(is_soo());
+    if (!ShouldSampleHashtablezInfo<CharAlloc>()) return HashtablezInfoHandle{};
+    return Sample(sizeof(slot_type), SooCapacity());
+  }
+
   inline void destroy_slots() {
     assert(!is_soo());
     if (PolicyTraits::template destroy_is_trivial<Alloc>()) return;
@@ -3374,7 +3417,20 @@
   // HashSetResizeHelper::FindFirstNonFullAfterResize(
   //    common(), old_capacity, hash)
   // can be called right after `resize`.
-  ABSL_ATTRIBUTE_NOINLINE void resize(size_t new_capacity) {
+  void resize(size_t new_capacity) {
+    resize_impl(new_capacity, HashtablezInfoHandle{});
+  }
+
+  // As above, except that we also accept a pre-sampled, forced infoz for
+  // SOO tables, since they need to switch from SOO to heap in order to
+  // store the infoz.
+  void resize_with_soo_infoz(HashtablezInfoHandle forced_infoz) {
+    assert(forced_infoz.IsSampled());
+    resize_impl(NextCapacity(SooCapacity()), forced_infoz);
+  }
+
+  ABSL_ATTRIBUTE_NOINLINE void resize_impl(
+      size_t new_capacity, HashtablezInfoHandle forced_infoz) {
     assert(IsValidCapacity(new_capacity));
     assert(!fits_in_soo(new_capacity));
     const bool was_soo = is_soo();
@@ -3382,7 +3438,8 @@
     const ctrl_t soo_slot_h2 =
         had_soo_slot ? static_cast<ctrl_t>(H2(hash_of(soo_slot())))
                      : ctrl_t::kEmpty;
-    HashSetResizeHelper resize_helper(common(), was_soo, had_soo_slot);
+    HashSetResizeHelper resize_helper(common(), was_soo, had_soo_slot,
+                                      forced_infoz);
     // Initialize HashSetResizeHelper::old_heap_or_soo_. We can't do this in
     // HashSetResizeHelper constructor because it can't transfer slots when
     // transfer_uses_memcpy is false.
@@ -3400,7 +3457,7 @@
     const bool grow_single_group =
         resize_helper.InitializeSlots<CharAlloc, sizeof(slot_type),
                                       PolicyTraits::transfer_uses_memcpy(),
-                                      alignof(slot_type)>(
+                                      SooEnabled(), alignof(slot_type)>(
             common(), CharAlloc(alloc_ref()), soo_slot_h2);
 
     // In the SooEnabled() case, capacity is never 0 so we don't check.
@@ -3621,26 +3678,29 @@
     return move_elements_allocs_unequal(std::move(that));
   }
 
- protected:
-  // Attempts to find `key` in the table; if it isn't found, returns an iterator
-  // where the value can be inserted into, with the control byte already set to
-  // `key`'s H2. Returns a bool indicating whether an insertion can take place.
   template <class K>
-  std::pair<iterator, bool> find_or_prepare_insert(const K& key) {
-    if (is_soo()) {
-      if (empty()) {
+  std::pair<iterator, bool> find_or_prepare_insert_soo(const K& key) {
+    if (empty()) {
+      const HashtablezInfoHandle infoz = try_sample_soo();
+      if (infoz.IsSampled()) {
+        resize_with_soo_infoz(infoz);
+      } else {
         common().set_full_soo();
         return {soo_iterator(), true};
       }
-      if (PolicyTraits::apply(EqualElement<K>{key, eq_ref()},
-                              PolicyTraits::element(soo_slot()))) {
-        return {soo_iterator(), false};
-      }
+    } else if (PolicyTraits::apply(EqualElement<K>{key, eq_ref()},
+                                   PolicyTraits::element(soo_slot()))) {
+      return {soo_iterator(), false};
+    } else {
       resize(NextCapacity(SooCapacity()));
-      const size_t index =
-          PrepareInsertAfterSoo(hash_ref()(key), sizeof(slot_type), common());
-      return {iterator_at(index), true};
     }
+    const size_t index =
+        PrepareInsertAfterSoo(hash_ref()(key), sizeof(slot_type), common());
+    return {iterator_at(index), true};
+  }
+
+  template <class K>
+  std::pair<iterator, bool> find_or_prepare_insert_non_soo(const K& key) {
     prefetch_heap_block();
     auto hash = hash_ref()(key);
     auto seq = probe(common(), hash);
@@ -3660,6 +3720,16 @@
     return {iterator_at(prepare_insert(hash)), true};
   }
 
+ protected:
+  // Attempts to find `key` in the table; if it isn't found, returns an iterator
+  // where the value can be inserted into, with the control byte already set to
+  // `key`'s H2. Returns a bool indicating whether an insertion can take place.
+  template <class K>
+  std::pair<iterator, bool> find_or_prepare_insert(const K& key) {
+    if (is_soo()) return find_or_prepare_insert_soo(key);
+    return find_or_prepare_insert_non_soo(key);
+  }
+
   // Given the hash of a value not currently in the table, finds the next
   // viable slot index to insert it at.
   //
@@ -3769,13 +3839,13 @@
     return static_cast<slot_type*>(common().soo_data());
   }
   const slot_type* soo_slot() const {
-    return reinterpret_cast<raw_hash_set*>(this)->soo_slot();
+    return const_cast<raw_hash_set*>(this)->soo_slot();
   }
   iterator soo_iterator() {
     return {SooControl(), soo_slot(), common().generation_ptr()};
   }
   const_iterator soo_iterator() const {
-    return reinterpret_cast<raw_hash_set*>(this)->soo_iterator();
+    return const_cast<raw_hash_set*>(this)->soo_iterator();
   }
   HashtablezInfoHandle infoz() {
     assert(!is_soo());
diff --git a/absl/container/internal/raw_hash_set_test.cc b/absl/container/internal/raw_hash_set_test.cc
index 66baeb6..f00cef8 100644
--- a/absl/container/internal/raw_hash_set_test.cc
+++ b/absl/container/internal/raw_hash_set_test.cc
@@ -52,7 +52,9 @@
 #include "absl/container/internal/hashtable_debug.h"
 #include "absl/container/internal/hashtablez_sampler.h"
 #include "absl/container/internal/test_allocator.h"
+#include "absl/functional/function_ref.h"
 #include "absl/hash/hash.h"
+#include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/memory/memory.h"
 #include "absl/meta/type_traits.h"
@@ -2035,6 +2037,9 @@
 }
 
 TYPED_TEST_P(SooTest, ReplacingDeletedSlotDoesNotRehash) {
+  // We need to disable hashtablez to avoid issues related to SOO and sampling.
+  SetHashtablezEnabled(false);
+
   size_t n;
   {
     // Compute n such that n is the maximum number of elements before rehash.
@@ -2388,6 +2393,9 @@
 // Verify that pointers are invalidated as soon as a second element is inserted.
 // This prevents dependency on pointer stability on small tables.
 TYPED_TEST_P(SooTest, UnstablePointers) {
+  // We need to disable hashtablez to avoid issues related to SOO and sampling.
+  SetHashtablezEnabled(false);
+
   TypeParam table;
 
   const auto addr = [&](int i) {
@@ -2523,7 +2531,7 @@
   constexpr bool soo_enabled = std::is_same<SooIntTable, TypeParam>::value;
   // Enable the feature even if the prod default is off.
   SetHashtablezEnabled(true);
-  SetHashtablezSampleParameter(100);
+  SetHashtablezSampleParameter(100);  // Sample ~1% of tables.
 
   auto& sampler = GlobalHashtablezSampler();
   size_t start_size = 0;
@@ -2557,25 +2565,26 @@
   absl::flat_hash_map<size_t, int> observed_checksums;
   absl::flat_hash_map<ssize_t, int> reservations;
   end_size += sampler.Iterate([&](const HashtablezInfo& info) {
-    if (preexisting_info.count(&info) == 0) {
-      observed_checksums[info.hashes_bitwise_xor.load(
-          std::memory_order_relaxed)]++;
-      reservations[info.max_reserve.load(std::memory_order_relaxed)]++;
-    }
-    EXPECT_EQ(info.inline_element_size, sizeof(typename TypeParam::value_type));
     ++end_size;
+    if (preexisting_info.contains(&info)) return;
+    observed_checksums[info.hashes_bitwise_xor.load(
+        std::memory_order_relaxed)]++;
+    reservations[info.max_reserve.load(std::memory_order_relaxed)]++;
+    EXPECT_EQ(info.inline_element_size, sizeof(typename TypeParam::value_type));
+
+    if (soo_enabled) {
+      EXPECT_EQ(info.soo_capacity, SooCapacity());
+    } else {
+      EXPECT_EQ(info.soo_capacity, 0);
+    }
   });
 
+  // Expect that we sampled at the requested sampling rate of ~1%.
   EXPECT_NEAR((end_size - start_size) / static_cast<double>(tables.size()),
               0.01, 0.005);
-  if (soo_enabled) {
-    EXPECT_EQ(observed_checksums.size(), 9);
-  } else {
-    EXPECT_EQ(observed_checksums.size(), 5);
-    for (const auto& [_, count] : observed_checksums) {
-      EXPECT_NEAR((100 * count) / static_cast<double>(tables.size()), 0.2,
-                  0.05);
-    }
+  EXPECT_EQ(observed_checksums.size(), 5);
+  for (const auto& [_, count] : observed_checksums) {
+    EXPECT_NEAR((100 * count) / static_cast<double>(tables.size()), 0.2, 0.05);
   }
 
   EXPECT_EQ(reservations.size(), 10);
@@ -2591,12 +2600,141 @@
 REGISTER_TYPED_TEST_SUITE_P(RawHashSamplerTest, Sample);
 using RawHashSamplerTestTypes = ::testing::Types<SooIntTable, NonSooIntTable>;
 INSTANTIATE_TYPED_TEST_SUITE_P(My, RawHashSamplerTest, RawHashSamplerTestTypes);
+
+std::vector<const HashtablezInfo*> SampleSooMutation(
+    absl::FunctionRef<void(SooIntTable&)> mutate_table) {
+  // Enable the feature even if the prod default is off.
+  SetHashtablezEnabled(true);
+  SetHashtablezSampleParameter(100);  // Sample ~1% of tables.
+
+  auto& sampler = GlobalHashtablezSampler();
+  size_t start_size = 0;
+  absl::flat_hash_set<const HashtablezInfo*> preexisting_info;
+  start_size += sampler.Iterate([&](const HashtablezInfo& info) {
+    preexisting_info.insert(&info);
+    ++start_size;
+  });
+
+  std::vector<SooIntTable> tables;
+  for (int i = 0; i < 1000000; ++i) {
+    tables.emplace_back();
+    mutate_table(tables.back());
+  }
+  size_t end_size = 0;
+  std::vector<const HashtablezInfo*> infos;
+  end_size += sampler.Iterate([&](const HashtablezInfo& info) {
+    ++end_size;
+    if (preexisting_info.contains(&info)) return;
+    infos.push_back(&info);
+  });
+
+  // Expect that we sampled at the requested sampling rate of ~1%.
+  EXPECT_NEAR((end_size - start_size) / static_cast<double>(tables.size()),
+              0.01, 0.005);
+  return infos;
+}
+
+TEST(RawHashSamplerTest, SooTableInsertToEmpty) {
+  if (SooIntTable().capacity() != SooCapacity()) {
+    CHECK_LT(sizeof(void*), 8) << "missing SOO coverage";
+    GTEST_SKIP() << "not SOO on this platform";
+  }
+  std::vector<const HashtablezInfo*> infos =
+      SampleSooMutation([](SooIntTable& t) { t.insert(1); });
+
+  for (const HashtablezInfo* info : infos) {
+    ASSERT_EQ(info->inline_element_size,
+              sizeof(typename SooIntTable::value_type));
+    ASSERT_EQ(info->soo_capacity, SooCapacity());
+    ASSERT_EQ(info->capacity, NextCapacity(SooCapacity()));
+    ASSERT_EQ(info->size, 1);
+    ASSERT_EQ(info->max_reserve, 0);
+    ASSERT_EQ(info->num_erases, 0);
+    ASSERT_EQ(info->max_probe_length, 0);
+    ASSERT_EQ(info->total_probe_length, 0);
+  }
+}
+
+TEST(RawHashSamplerTest, SooTableReserveToEmpty) {
+  if (SooIntTable().capacity() != SooCapacity()) {
+    CHECK_LT(sizeof(void*), 8) << "missing SOO coverage";
+    GTEST_SKIP() << "not SOO on this platform";
+  }
+  std::vector<const HashtablezInfo*> infos =
+      SampleSooMutation([](SooIntTable& t) { t.reserve(100); });
+
+  for (const HashtablezInfo* info : infos) {
+    ASSERT_EQ(info->inline_element_size,
+              sizeof(typename SooIntTable::value_type));
+    ASSERT_EQ(info->soo_capacity, SooCapacity());
+    ASSERT_GE(info->capacity, 100);
+    ASSERT_EQ(info->size, 0);
+    ASSERT_EQ(info->max_reserve, 100);
+    ASSERT_EQ(info->num_erases, 0);
+    ASSERT_EQ(info->max_probe_length, 0);
+    ASSERT_EQ(info->total_probe_length, 0);
+  }
+}
+
+// This tests that reserve on a full SOO table doesn't incorrectly result in new
+// (over-)sampling.
+TEST(RawHashSamplerTest, SooTableReserveToFullSoo) {
+  if (SooIntTable().capacity() != SooCapacity()) {
+    CHECK_LT(sizeof(void*), 8) << "missing SOO coverage";
+    GTEST_SKIP() << "not SOO on this platform";
+  }
+  std::vector<const HashtablezInfo*> infos =
+      SampleSooMutation([](SooIntTable& t) {
+        t.insert(1);
+        t.reserve(100);
+      });
+
+  for (const HashtablezInfo* info : infos) {
+    ASSERT_EQ(info->inline_element_size,
+              sizeof(typename SooIntTable::value_type));
+    ASSERT_EQ(info->soo_capacity, SooCapacity());
+    ASSERT_GE(info->capacity, 100);
+    ASSERT_EQ(info->size, 1);
+    ASSERT_EQ(info->max_reserve, 100);
+    ASSERT_EQ(info->num_erases, 0);
+    ASSERT_EQ(info->max_probe_length, 0);
+    ASSERT_EQ(info->total_probe_length, 0);
+  }
+}
+
+// This tests that rehash(0) on a sampled table with size that fits in SOO
+// doesn't incorrectly result in losing sampling.
+TEST(RawHashSamplerTest, SooTableRehashShrinkWhenSizeFitsInSoo) {
+  if (SooIntTable().capacity() != SooCapacity()) {
+    CHECK_LT(sizeof(void*), 8) << "missing SOO coverage";
+    GTEST_SKIP() << "not SOO on this platform";
+  }
+  std::vector<const HashtablezInfo*> infos =
+      SampleSooMutation([](SooIntTable& t) {
+        t.reserve(100);
+        t.insert(1);
+        EXPECT_GE(t.capacity(), 100);
+        t.rehash(0);
+      });
+
+  for (const HashtablezInfo* info : infos) {
+    ASSERT_EQ(info->inline_element_size,
+              sizeof(typename SooIntTable::value_type));
+    ASSERT_EQ(info->soo_capacity, SooCapacity());
+    ASSERT_EQ(info->capacity, NextCapacity(SooCapacity()));
+    ASSERT_EQ(info->size, 1);
+    ASSERT_EQ(info->max_reserve, 100);
+    ASSERT_EQ(info->num_erases, 0);
+    ASSERT_EQ(info->max_probe_length, 0);
+    ASSERT_EQ(info->total_probe_length, 0);
+  }
+}
 #endif  // ABSL_INTERNAL_HASHTABLEZ_SAMPLE
 
 TEST(RawHashSamplerTest, DoNotSampleCustomAllocators) {
   // Enable the feature even if the prod default is off.
   SetHashtablezEnabled(true);
-  SetHashtablezSampleParameter(100);
+  SetHashtablezSampleParameter(100);  // Sample ~1% of tables.
 
   auto& sampler = GlobalHashtablezSampler();
   size_t start_size = 0;
@@ -2974,7 +3112,7 @@
   bool frozen = true;
   TypeParam t{FreezableAlloc<typename TypeParam::value_type>(&frozen)};
   if (t.capacity() != SooCapacity()) {
-    EXPECT_LT(sizeof(void*), 8);
+    CHECK_LT(sizeof(void*), 8) << "missing SOO coverage";
     GTEST_SKIP() << "not SOO on this platform";
   }
 
@@ -3006,14 +3144,18 @@
                      FreezableSizedValueSooTable<16>>;
 INSTANTIATE_TYPED_TEST_SUITE_P(My, SooTable, FreezableSooTableTypes);
 
-TEST(Table, RehashToSoo) {
+TEST(Table, RehashToSooUnsampled) {
   SooIntTable t;
   if (t.capacity() != SooCapacity()) {
-    EXPECT_LT(sizeof(void*), 8);
+    CHECK_LT(sizeof(void*), 8) << "missing SOO coverage";
     GTEST_SKIP() << "not SOO on this platform";
   }
 
-  t.reserve(10);
+  // We disable hashtablez sampling for this test to ensure that the table isn't
+  // sampled. When the table is sampled, it won't rehash down to SOO.
+  SetHashtablezEnabled(false);
+
+  t.reserve(100);
   t.insert(0);
   EXPECT_EQ(*t.begin(), 0);
 
@@ -3026,7 +3168,7 @@
   EXPECT_EQ(t.find(1), t.end());
 }
 
-TEST(Table, ResizeToNonSoo) {
+TEST(Table, ReserveToNonSoo) {
   for (int reserve_capacity : {8, 100000}) {
     SooIntTable t;
     t.insert(0);