Add sanitizer-mode validation for uses of references to swisstable elements that may have been invalidated by a container move.

PiperOrigin-RevId: 578649798
Change-Id: Icfee98d3a0399b545ec6ec59c5b52ae5e006218b
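
For illustration, a minimal sketch of the bug class this validation catches, using absl::flat_hash_set as a representative swisstable container (this example is not part of the change):

    #include <cstdio>
    #include <utility>
    #include "absl/container/flat_hash_set.h"

    int main() {
      absl::flat_hash_set<int> s1;
      s1.insert(1);
      // A raw pointer into s1's backing array.
      const int* p = &*s1.begin();

      absl::flat_hash_set<int> s2 = std::move(s1);
      // The backing array transferred to s2, so *p previously read "valid"
      // memory and the bug went unnoticed. With this change, in sanitizer
      // builds the move may rehash s2, freeing the old array, so ASan can
      // report a heap-use-after-free here instead.
      std::printf("%d\n", *p);  // use of a reference invalidated by the move
      return 0;
    }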
diff --git a/absl/container/internal/raw_hash_set.cc b/absl/container/internal/raw_hash_set.cc
index df64e7e..6e5941d 100644
--- a/absl/container/internal/raw_hash_set.cc
+++ b/absl/container/internal/raw_hash_set.cc
@@ -67,6 +67,16 @@
   return value ^ static_cast<size_t>(reinterpret_cast<uintptr_t>(&counter));
 }
 
+bool ShouldRehashForBugDetection(const ctrl_t* ctrl, size_t capacity) {
+  // Note: we can't use the abseil-random library because abseil-random
+  // depends on swisstable. We want to return true with probability
+  // `min(1, RehashProbabilityConstant() / capacity)`. In order to do this,
+  // we probe based on a random hash and see if the offset is less than
+  // RehashProbabilityConstant().
+  return probe(ctrl, capacity, absl::HashOf(RandomSeed())).offset() <
+         RehashProbabilityConstant();
+}
+
 }  // namespace
 
 GenerationType* EmptyGeneration() {
@@ -84,13 +94,12 @@
                                               size_t capacity) const {
   if (reserved_growth_ == kReservedGrowthJustRanOut) return true;
   if (reserved_growth_ > 0) return false;
-  // Note: we can't use the abseil-random library because abseil-random
-  // depends on swisstable. We want to return true with probability
-  // `min(1, RehashProbabilityConstant() / capacity())`. In order to do this,
-  // we probe based on a random hash and see if the offset is less than
-  // RehashProbabilityConstant().
-  return probe(ctrl, capacity, absl::HashOf(RandomSeed())).offset() <
-         RehashProbabilityConstant();
+  return ShouldRehashForBugDetection(ctrl, capacity);
+}
+
+bool CommonFieldsGenerationInfoEnabled::should_rehash_for_bug_detection_on_move(
+    const ctrl_t* ctrl, size_t capacity) const {
+  return ShouldRehashForBugDetection(ctrl, capacity);
 }
 
 bool ShouldInsertBackwards(size_t hash, const ctrl_t* ctrl) {
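
For reference, why the probe trick yields the target probability: the initial probe offset for a random hash is approximately uniform over the capacity slots (capacity has the form 2^k - 1 and is used as a mask), so the offset falls below RehashProbabilityConstant() with probability about min(1, RehashProbabilityConstant() / capacity). A self-contained simulation sketch (kRehashProbability = 16 mirrors the current constant, but treat the value as an assumption here):

    #include <cstddef>
    #include <cstdio>
    #include <random>

    int main() {
      constexpr size_t kRehashProbability = 16;  // stand-in for RehashProbabilityConstant()
      constexpr size_t kCapacityMask = 1023;     // capacity of the form 2^k - 1
      std::mt19937_64 rng(42);
      const size_t kTrials = 1000000;
      size_t hits = 0;
      for (size_t i = 0; i < kTrials; ++i) {
        // Model the initial probe offset as hash & capacity (uniform-ish).
        const size_t offset = rng() & kCapacityMask;
        if (offset < kRehashProbability) ++hits;
      }
      // Expect roughly 16/1024, i.e. about 1.6%.
      std::printf("empirical p = %.4f, target p = %.4f\n",
                  static_cast<double>(hits) / kTrials,
                  static_cast<double>(kRehashProbability) / (kCapacityMask + 1));
      return 0;
    }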
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index 67964f5..f702f6a 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -840,6 +840,9 @@
   // whenever reserved_growth_ is zero.
   bool should_rehash_for_bug_detection_on_insert(const ctrl_t* ctrl,
                                                  size_t capacity) const;
+  // Similar to above, except that we don't depend on reserved_growth_.
+  bool should_rehash_for_bug_detection_on_move(const ctrl_t* ctrl,
+                                               size_t capacity) const;
   void maybe_increment_generation_on_insert() {
     if (reserved_growth_ == kReservedGrowthJustRanOut) reserved_growth_ = 0;
 
@@ -895,6 +898,9 @@
   bool should_rehash_for_bug_detection_on_insert(const ctrl_t*, size_t) const {
     return false;
   }
+  bool should_rehash_for_bug_detection_on_move(const ctrl_t*, size_t) const {
+    return false;
+  }
   void maybe_increment_generation_on_insert() {}
   void increment_generation() {}
   void reset_reserved_growth(size_t, size_t) {}
@@ -1068,6 +1074,10 @@
     return CommonFieldsGenerationInfo::
         should_rehash_for_bug_detection_on_insert(control(), capacity());
   }
+  bool should_rehash_for_bug_detection_on_move() const {
+    return CommonFieldsGenerationInfo::
+        should_rehash_for_bug_detection_on_move(control(), capacity());
+  }
   void maybe_increment_generation_on_move() {
     if (capacity() == 0) return;
     increment_generation();
@@ -1932,17 +1942,17 @@
         settings_(std::move(that.common()), that.hash_ref(), that.eq_ref(),
                   that.alloc_ref()) {
     that.common() = CommonFields{};
-    common().maybe_increment_generation_on_move();
+    maybe_increment_generation_or_rehash_on_move();
   }
 
   raw_hash_set(raw_hash_set&& that, const allocator_type& a)
       : settings_(CommonFields{}, that.hash_ref(), that.eq_ref(), a) {
     if (a == that.alloc_ref()) {
       std::swap(common(), that.common());
+      maybe_increment_generation_or_rehash_on_move();
     } else {
       move_elements_allocs_unequal(std::move(that));
     }
-    common().maybe_increment_generation_on_move();
   }
 
   raw_hash_set& operator=(const raw_hash_set& that) {
@@ -2720,6 +2730,13 @@
     }
   }
 
+  void maybe_increment_generation_or_rehash_on_move() {
+    common().maybe_increment_generation_on_move();
+    if (!empty() && common().should_rehash_for_bug_detection_on_move()) {
+      resize(capacity());
+    }
+  }
+
   template<bool propagate_alloc>
   raw_hash_set& assign_impl(raw_hash_set&& that) {
     // We don't bother checking for this/that aliasing. We just need to avoid
@@ -2732,7 +2749,7 @@
     CopyAlloc(alloc_ref(), that.alloc_ref(),
               std::integral_constant<bool, propagate_alloc>());
     that.common() = CommonFields{};
-    common().maybe_increment_generation_on_move();
+    maybe_increment_generation_or_rehash_on_move();
     return *this;
   }
 
@@ -2746,6 +2763,7 @@
     }
     that.dealloc();
     that.common() = CommonFields{};
+    maybe_increment_generation_or_rehash_on_move();
     return *this;
   }
 
@@ -2767,7 +2785,6 @@
     // TODO(b/296061262): move instead of copying hash/eq.
     hash_ref() = that.hash_ref();
     eq_ref() = that.eq_ref();
-    common().maybe_increment_generation_on_move();
     return move_elements_allocs_unequal(std::move(that));
   }
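
Usage implication, sketched with absl::flat_hash_set (an assumption; any swisstable container behaves the same): since a move may now rehash, iterators and element pointers taken before the move must be re-acquired from the destination table.

    #include <utility>
    #include "absl/container/flat_hash_set.h"

    void UseAfterMoveSafely() {
      absl::flat_hash_set<int> src;
      src.insert(42);

      absl::flat_hash_set<int> dst = std::move(src);
      // Do not reuse iterators or pointers obtained from `src` before the
      // move; re-acquire them from `dst` instead.
      auto it = dst.find(42);
      if (it != dst.end()) {
        // *it is safe to use here.
      }
    }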
 
diff --git a/absl/container/internal/raw_hash_set_allocator_test.cc b/absl/container/internal/raw_hash_set_allocator_test.cc
index 0bb61cf..05dcfaa 100644
--- a/absl/container/internal/raw_hash_set_allocator_test.cc
+++ b/absl/container/internal/raw_hash_set_allocator_test.cc
@@ -22,6 +22,7 @@
 #include <type_traits>
 #include <utility>
 
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include "absl/base/config.h"
 #include "absl/container/internal/raw_hash_set.h"
@@ -31,6 +32,7 @@
 ABSL_NAMESPACE_BEGIN
 namespace container_internal {
 namespace {
+using ::testing::AnyOf;
 
 enum AllocSpec {
   kPropagateOnCopy = 1,
@@ -287,8 +289,8 @@
   t1.insert(0);
   Table u(std::move(t1));
   auto it = u.begin();
-  EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -296,8 +298,8 @@
   t1.insert(0);
   Table u(std::move(t1));
   auto it = u.begin();
-  EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -305,8 +307,8 @@
   t1.insert(0);
   Table u(std::move(t1), a1);
   auto it = u.begin();
-  EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -314,8 +316,8 @@
   t1.insert(0);
   Table u(std::move(t1), a1);
   auto it = u.begin();
-  EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -325,8 +327,8 @@
   it = u.find(0);
   EXPECT_EQ(a2, u.get_allocator());
   EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(1, a2.num_allocs());
-  EXPECT_EQ(1, it->num_moves());
+  EXPECT_THAT(a2.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(1, 2));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -336,8 +338,8 @@
   it = u.find(0);
   EXPECT_EQ(a2, u.get_allocator());
   EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(1, a2.num_allocs());
-  EXPECT_EQ(1, it->num_moves());
+  EXPECT_THAT(a2.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(1, 2));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -345,8 +347,8 @@
   auto it = t1.insert(0).first;
   Table u(0, a1);
   u = t1;
-  EXPECT_EQ(2, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(2, 3));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(1, it->num_copies());
 }
 
@@ -354,8 +356,8 @@
   auto it = t1.insert(0).first;
   Table u(0, a1);
   u = t1;
-  EXPECT_EQ(2, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(2, 3));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(1, it->num_copies());
 }
 
@@ -364,9 +366,9 @@
   Table u(0, a2);
   u = t1;
   EXPECT_EQ(a1, u.get_allocator());
-  EXPECT_EQ(2, a1.num_allocs());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(2, 3));
   EXPECT_EQ(0, a2.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(1, it->num_copies());
 }
 
@@ -376,8 +378,8 @@
   u = t1;
   EXPECT_EQ(a2, u.get_allocator());
   EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(1, a2.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a2.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(1, it->num_copies());
 }
 
@@ -387,8 +389,8 @@
   u = std::move(t1);
   auto it = u.begin();
   EXPECT_EQ(a1, u.get_allocator());
-  EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -398,8 +400,8 @@
   u = std::move(t1);
   auto it = u.begin();
   EXPECT_EQ(a1, u.get_allocator());
-  EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -409,9 +411,9 @@
   u = std::move(t1);
   auto it = u.begin();
   EXPECT_EQ(a1, u.get_allocator());
-  EXPECT_EQ(1, a1.num_allocs());
+  EXPECT_THAT(a1.num_allocs(), AnyOf(1, 2));
   EXPECT_EQ(0, a2.num_allocs());
-  EXPECT_EQ(0, it->num_moves());
+  EXPECT_THAT(it->num_moves(), AnyOf(0, 1));
   EXPECT_EQ(0, it->num_copies());
 }
 
@@ -422,8 +424,8 @@
   auto it = u.find(0);
   EXPECT_EQ(a2, u.get_allocator());
   EXPECT_EQ(1, a1.num_allocs());
-  EXPECT_EQ(1, a2.num_allocs());
-  EXPECT_EQ(1, it->num_moves());
+  EXPECT_THAT(a2.num_allocs(), AnyOf(1, 2));
+  EXPECT_THAT(it->num_moves(), AnyOf(1, 2));
   EXPECT_EQ(0, it->num_copies());
 }
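
The relaxed counts above follow from the probabilistic rehash: a move may now perform one extra allocation (a fresh backing array) and one extra move per element (relocation during the rehash), so each previously exact expectation becomes a two-value AnyOf. A minimal sketch of the pattern (self-contained; the counter is illustrative, not from the tests):

    #include "gmock/gmock.h"
    #include "gtest/gtest.h"

    TEST(Example, MoveCountMayBeOffByOne) {
      // Without the bug-detection rehash the count is 1; if the rehash
      // fires it is 2, so the test accepts either value.
      const int num_allocs = 1;
      EXPECT_THAT(num_allocs, ::testing::AnyOf(1, 2));
    }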
 
diff --git a/absl/container/internal/raw_hash_set_test.cc b/absl/container/internal/raw_hash_set_test.cc
index 4e67b79..4937e6f 100644
--- a/absl/container/internal/raw_hash_set_test.cc
+++ b/absl/container/internal/raw_hash_set_test.cc
@@ -2375,10 +2375,17 @@
   IntTable t1, t2;
   t1.insert(1);
   auto it = t1.begin();
+  // ptr will become invalidated on rehash.
+  const int64_t* ptr = &*it;
+  (void)ptr;
+
   t2 = std::move(t1);
   EXPECT_DEATH_IF_SUPPORTED(*it, kInvalidIteratorDeathMessage);
   EXPECT_DEATH_IF_SUPPORTED(void(it == t2.begin()),
                             kInvalidIteratorDeathMessage);
+#ifdef ABSL_HAVE_ADDRESS_SANITIZER
+  EXPECT_DEATH_IF_SUPPORTED(std::cout << *ptr, "heap-use-after-free");
+#endif
 }
 
 TEST(Table, ReservedGrowthUpdatesWhenTableDoesntGrow) {
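
Note: the death tests above require swisstable generations, which are enabled automatically in sanitizer builds (and can be forced with ABSL_SWISSTABLE_ENABLE_GENERATIONS). One way to run the test under ASan, assuming a Bazel checkout (the flags below are standard clang/gcc and Bazel options, not specific to this change):

    bazel test //absl/container:raw_hash_set_test \
        --copt=-fsanitize=address --linkopt=-fsanitize=address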