Type erased hash_slot_fn that depends only on key types (and hash function).

PiperOrigin-RevId: 603148301
Change-Id: Ie2e5702995c9e1ef4d5aaab23bc89a1eb5007a86
diff --git a/absl/container/BUILD.bazel b/absl/container/BUILD.bazel
index 9163394..5160cce 100644
--- a/absl/container/BUILD.bazel
+++ b/absl/container/BUILD.bazel
@@ -357,6 +357,7 @@
     copts = ABSL_DEFAULT_COPTS,
     linkopts = ABSL_DEFAULT_LINKOPTS,
     deps = [
+        ":container_memory",
         ":hash_function_defaults",
         ":node_slot_policy",
         ":raw_hash_set",
@@ -504,6 +505,7 @@
     copts = ABSL_TEST_COPTS,
     linkopts = ABSL_DEFAULT_LINKOPTS,
     deps = [
+        ":container_memory",
         ":hash_policy_traits",
         "@com_google_googletest//:gtest",
         "@com_google_googletest//:gtest_main",
@@ -714,6 +716,7 @@
     tags = ["benchmark"],
     visibility = ["//visibility:private"],
     deps = [
+        ":container_memory",
         ":hash_function_defaults",
         ":raw_hash_set",
         "//absl/base:raw_logging_internal",
@@ -753,6 +756,7 @@
     copts = ABSL_TEST_COPTS,
     linkopts = ABSL_DEFAULT_LINKOPTS,
     deps = [
+        ":container_memory",
         ":raw_hash_set",
         ":tracked",
         "//absl/base:config",
diff --git a/absl/container/CMakeLists.txt b/absl/container/CMakeLists.txt
index 3f80d29..85af0bd 100644
--- a/absl/container/CMakeLists.txt
+++ b/absl/container/CMakeLists.txt
@@ -401,6 +401,7 @@
   COPTS
     ${ABSL_DEFAULT_COPTS}
   DEPS
+    absl::container_memory
     absl::core_headers
     absl::hash_function_defaults
     absl::node_slot_policy
@@ -560,6 +561,7 @@
   COPTS
     ${ABSL_TEST_COPTS}
   DEPS
+    absl::container_memory
     absl::hash_policy_traits
     GTest::gmock_main
 )
@@ -776,6 +778,7 @@
     ${ABSL_TEST_COPTS}
   DEPS
     absl::config
+    absl::container_memory
     absl::raw_hash_set
     absl::tracked
     GTest::gmock_main
diff --git a/absl/container/flat_hash_map.h b/absl/container/flat_hash_map.h
index 808a62b..2a1ad0f 100644
--- a/absl/container/flat_hash_map.h
+++ b/absl/container/flat_hash_map.h
@@ -593,6 +593,13 @@
                                                    std::forward<Args>(args)...);
   }
 
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return memory_internal::IsLayoutCompatible<K, V>::value
+               ? &TypeErasedApplyToSlotFn<Hash, K>
+               : nullptr;
+  }
+
   static size_t space_used(const slot_type*) { return 0; }
 
   static std::pair<const K, V>& element(slot_type* slot) { return slot->value; }
diff --git a/absl/container/flat_hash_set.h b/absl/container/flat_hash_set.h
index cd74923..1be0019 100644
--- a/absl/container/flat_hash_set.h
+++ b/absl/container/flat_hash_set.h
@@ -29,6 +29,7 @@
 #ifndef ABSL_CONTAINER_FLAT_HASH_SET_H_
 #define ABSL_CONTAINER_FLAT_HASH_SET_H_
 
+#include <cstddef>
 #include <memory>
 #include <type_traits>
 #include <utility>
@@ -492,6 +493,11 @@
   }
 
   static size_t space_used(const T*) { return 0; }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return &TypeErasedApplyToSlotFn<Hash, T>;
+  }
 };
 }  // namespace container_internal
 
diff --git a/absl/container/internal/container_memory.h b/absl/container/internal/container_memory.h
index d86d7b8..ba8e08a 100644
--- a/absl/container/internal/container_memory.h
+++ b/absl/container/internal/container_memory.h
@@ -464,6 +464,26 @@
   }
 };
 
+// Type erased function for computing hash of the slot.
+using HashSlotFn = size_t (*)(const void* hash_fn, void* slot);
+
+// Type erased function to apply `Fn` to data inside of the `slot`.
+// The data is expected to have type `T`.
+template <class Fn, class T>
+size_t TypeErasedApplyToSlotFn(const void* fn, void* slot) {
+  const auto* f = static_cast<const Fn*>(fn);
+  return (*f)(*static_cast<const T*>(slot));
+}
+
+// Type erased function to apply `Fn` to data inside of the `*slot_ptr`.
+// The data is expected to have type `T`.
+template <class Fn, class T>
+size_t TypeErasedDerefAndApplyToSlotFn(const void* fn, void* slot_ptr) {
+  const auto* f = static_cast<const Fn*>(fn);
+  const T* slot = *static_cast<const T**>(slot_ptr);
+  return (*f)(*slot);
+}
+
 }  // namespace container_internal
 ABSL_NAMESPACE_END
 }  // namespace absl
diff --git a/absl/container/internal/container_memory_test.cc b/absl/container/internal/container_memory_test.cc
index c973ac2..7e4357d 100644
--- a/absl/container/internal/container_memory_test.cc
+++ b/absl/container/internal/container_memory_test.cc
@@ -298,6 +298,20 @@
   }
 }
 
+TEST(ApplyTest, TypeErasedApplyToSlotFn) {
+  size_t x = 7;
+  auto fn = [](size_t v) { return v * 2; };
+  EXPECT_EQ((TypeErasedApplyToSlotFn<decltype(fn), size_t>(&fn, &x)), 14);
+}
+
+TEST(ApplyTest, TypeErasedDerefAndApplyToSlotFn) {
+  size_t x = 7;
+  auto fn = [](size_t v) { return v * 2; };
+  size_t* x_ptr = &x;
+  EXPECT_EQ(
+      (TypeErasedDerefAndApplyToSlotFn<decltype(fn), size_t>(&fn, &x_ptr)), 14);
+}
+
 }  // namespace
 }  // namespace container_internal
 ABSL_NAMESPACE_END
diff --git a/absl/container/internal/hash_policy_traits.h b/absl/container/internal/hash_policy_traits.h
index 164ec12..86ffd1b 100644
--- a/absl/container/internal/hash_policy_traits.h
+++ b/absl/container/internal/hash_policy_traits.h
@@ -148,6 +148,41 @@
   static auto value(T* elem) -> decltype(P::value(elem)) {
     return P::value(elem);
   }
+
+  using HashSlotFn = size_t (*)(const void* hash_fn, void* slot);
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+// get_hash_slot_fn may return nullptr to signal that non type erased function
+// should be used. GCC warns against comparing function address with nullptr.
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic push
+// silent error: the address of * will never be NULL [-Werror=address]
+#pragma GCC diagnostic ignored "-Waddress"
+#endif
+    return Policy::template get_hash_slot_fn<Hash>() == nullptr
+               ? &hash_slot_fn_non_type_erased<Hash>
+               : Policy::template get_hash_slot_fn<Hash>();
+#if defined(__GNUC__) && !defined(__clang__)
+#pragma GCC diagnostic pop
+#endif
+  }
+
+ private:
+  template <class Hash>
+  struct HashElement {
+    template <class K, class... Args>
+    size_t operator()(const K& key, Args&&...) const {
+      return h(key);
+    }
+    const Hash& h;
+  };
+
+  template <class Hash>
+  static size_t hash_slot_fn_non_type_erased(const void* hash_fn, void* slot) {
+    return Policy::apply(HashElement<Hash>{*static_cast<const Hash*>(hash_fn)},
+                         Policy::element(static_cast<slot_type*>(slot)));
+  }
 };
 
 }  // namespace container_internal
diff --git a/absl/container/internal/hash_policy_traits_test.cc b/absl/container/internal/hash_policy_traits_test.cc
index 82d7cc3..2d2c7c2 100644
--- a/absl/container/internal/hash_policy_traits_test.cc
+++ b/absl/container/internal/hash_policy_traits_test.cc
@@ -14,12 +14,14 @@
 
 #include "absl/container/internal/hash_policy_traits.h"
 
+#include <cstddef>
 #include <functional>
 #include <memory>
 #include <new>
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
+#include "absl/container/internal/container_memory.h"
 
 namespace absl {
 ABSL_NAMESPACE_BEGIN
@@ -42,6 +44,11 @@
   static int apply(int v) { return apply_impl(v); }
   static std::function<int(int)> apply_impl;
   static std::function<Slot&(Slot*)> value;
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 std::function<int(int)> PolicyWithoutOptionalOps::apply_impl;
@@ -74,6 +81,63 @@
   EXPECT_EQ(&b, &hash_policy_traits<PolicyWithoutOptionalOps>::value(&a));
 }
 
+struct Hash {
+  size_t operator()(Slot a) const { return static_cast<size_t>(a) * 5; }
+};
+
+struct PolicyNoHashFn {
+  using slot_type = Slot;
+  using key_type = Slot;
+  using init_type = Slot;
+
+  static size_t* apply_called_count;
+
+  static Slot& element(Slot* slot) { return *slot; }
+  template <typename Fn>
+  static size_t apply(const Fn& fn, int v) {
+    ++(*apply_called_count);
+    return fn(v);
+  }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
+};
+
+size_t* PolicyNoHashFn::apply_called_count;
+
+struct PolicyCustomHashFn : PolicyNoHashFn {
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return &TypeErasedApplyToSlotFn<Hash, int>;
+  }
+};
+
+TEST(HashTest, PolicyNoHashFn_get_hash_slot_fn) {
+  size_t apply_called_count = 0;
+  PolicyNoHashFn::apply_called_count = &apply_called_count;
+
+  Hash hasher;
+  Slot value = 7;
+  auto* fn = hash_policy_traits<PolicyNoHashFn>::get_hash_slot_fn<Hash>();
+  EXPECT_NE(fn, nullptr);
+  EXPECT_EQ(fn(&hasher, &value), hasher(value));
+  EXPECT_EQ(apply_called_count, 1);
+}
+
+TEST(HashTest, PolicyCustomHashFn_get_hash_slot_fn) {
+  size_t apply_called_count = 0;
+  PolicyNoHashFn::apply_called_count = &apply_called_count;
+
+  Hash hasher;
+  Slot value = 7;
+  auto* fn = hash_policy_traits<PolicyCustomHashFn>::get_hash_slot_fn<Hash>();
+  EXPECT_EQ(fn, PolicyCustomHashFn::get_hash_slot_fn<Hash>());
+  EXPECT_EQ(fn(&hasher, &value), hasher(value));
+  EXPECT_EQ(apply_called_count, 0);
+}
+
 }  // namespace
 }  // namespace container_internal
 ABSL_NAMESPACE_END
diff --git a/absl/container/internal/raw_hash_set.cc b/absl/container/internal/raw_hash_set.cc
index 9f8ea51..02301e1 100644
--- a/absl/container/internal/raw_hash_set.cc
+++ b/absl/container/internal/raw_hash_set.cc
@@ -140,7 +140,7 @@
   return reinterpret_cast<void*>(reinterpret_cast<uintptr_t>(slot) - slot_size);
 }
 
-void DropDeletesWithoutResize(CommonFields& common,
+void DropDeletesWithoutResize(CommonFields& common, const void* hash_fn,
                               const PolicyFunctions& policy, void* tmp_space) {
   void* set = &common;
   void* slot_array = common.slot_array();
@@ -175,7 +175,7 @@
        ++i, slot_ptr = NextSlot(slot_ptr, slot_size)) {
     assert(slot_ptr == SlotAddress(slot_array, i, slot_size));
     if (!IsDeleted(ctrl[i])) continue;
-    const size_t hash = (*hasher)(set, slot_ptr);
+    const size_t hash = (*hasher)(hash_fn, slot_ptr);
     const FindInfo target = find_first_non_full(common, hash);
     const size_t new_i = target.offset;
     total_probe_length += target.probe_length;
diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h
index c464de6..3b79382 100644
--- a/absl/container/internal/raw_hash_set.h
+++ b/absl/container/internal/raw_hash_set.h
@@ -1802,7 +1802,7 @@
   size_t slot_size;
 
   // Returns the hash of the pointed-to slot.
-  size_t (*hash_slot)(void* set, void* slot);
+  size_t (*hash_slot)(const void* hash_fn, void* slot);
 
   // Transfer the contents of src_slot to dst_slot.
   void (*transfer)(void* set, void* dst_slot, void* src_slot);
@@ -1847,7 +1847,7 @@
 }
 
 // Type-erased version of raw_hash_set::drop_deletes_without_resize.
-void DropDeletesWithoutResize(CommonFields& common,
+void DropDeletesWithoutResize(CommonFields& common, const void* hash_fn,
                               const PolicyFunctions& policy, void* tmp_space);
 
 // A SwissTable.
@@ -2989,7 +2989,7 @@
   inline void drop_deletes_without_resize() {
     // Stack-allocate space for swapping elements.
     alignas(slot_type) unsigned char tmp[sizeof(slot_type)];
-    DropDeletesWithoutResize(common(), GetPolicyFunctions(), tmp);
+    DropDeletesWithoutResize(common(), &hash_ref(), GetPolicyFunctions(), tmp);
   }
 
   // Called whenever the table *might* need to conditionally grow.
@@ -3238,13 +3238,6 @@
     return settings_.template get<3>();
   }
 
-  // Make type-specific functions for this type's PolicyFunctions struct.
-  static size_t hash_slot_fn(void* set, void* slot) {
-    auto* h = static_cast<raw_hash_set*>(set);
-    return PolicyTraits::apply(
-        HashElement{h->hash_ref()},
-        PolicyTraits::element(static_cast<slot_type*>(slot)));
-  }
   static void transfer_slot_fn(void* set, void* dst, void* src) {
     auto* h = static_cast<raw_hash_set*>(set);
     h->transfer(static_cast<slot_type*>(dst), static_cast<slot_type*>(src));
@@ -3266,7 +3259,7 @@
   static const PolicyFunctions& GetPolicyFunctions() {
     static constexpr PolicyFunctions value = {
         sizeof(slot_type),
-        &raw_hash_set::hash_slot_fn,
+        PolicyTraits::template get_hash_slot_fn<hasher>(),
         PolicyTraits::transfer_uses_memcpy()
             ? TransferRelocatable<sizeof(slot_type)>
             : &raw_hash_set::transfer_slot_fn,
diff --git a/absl/container/internal/raw_hash_set_allocator_test.cc b/absl/container/internal/raw_hash_set_allocator_test.cc
index 05dcfaa..7e7a506 100644
--- a/absl/container/internal/raw_hash_set_allocator_test.cc
+++ b/absl/container/internal/raw_hash_set_allocator_test.cc
@@ -25,6 +25,7 @@
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
 #include "absl/base/config.h"
+#include "absl/container/internal/container_memory.h"
 #include "absl/container/internal/raw_hash_set.h"
 #include "absl/container/internal/tracked.h"
 
@@ -133,7 +134,7 @@
 };
 
 struct Identity {
-  int32_t operator()(int32_t v) const { return v; }
+  size_t operator()(int32_t v) const { return static_cast<size_t>(v); }
 };
 
 struct Policy {
@@ -178,6 +179,11 @@
   }
 
   static slot_type& element(slot_type* slot) { return *slot; }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 template <int Spec>
diff --git a/absl/container/internal/raw_hash_set_benchmark.cc b/absl/container/internal/raw_hash_set_benchmark.cc
index 05a0642..bc8184d 100644
--- a/absl/container/internal/raw_hash_set_benchmark.cc
+++ b/absl/container/internal/raw_hash_set_benchmark.cc
@@ -24,6 +24,7 @@
 #include <vector>
 
 #include "absl/base/internal/raw_logging.h"
+#include "absl/container/internal/container_memory.h"
 #include "absl/container/internal/hash_function_defaults.h"
 #include "absl/container/internal/raw_hash_set.h"
 #include "absl/strings/str_format.h"
@@ -59,6 +60,11 @@
   static auto apply(F&& f, int64_t x) -> decltype(std::forward<F>(f)(x, x)) {
     return std::forward<F>(f)(x, x);
   }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 class StringPolicy {
@@ -117,6 +123,11 @@
     return apply_impl(std::forward<F>(f),
                       PairArgs(std::forward<Args>(args)...));
   }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 struct StringHash : container_internal::hash_default_hash<absl::string_view> {
diff --git a/absl/container/internal/raw_hash_set_probe_benchmark.cc b/absl/container/internal/raw_hash_set_probe_benchmark.cc
index 5d4184b..8f36305 100644
--- a/absl/container/internal/raw_hash_set_probe_benchmark.cc
+++ b/absl/container/internal/raw_hash_set_probe_benchmark.cc
@@ -70,6 +70,11 @@
       -> decltype(std::forward<F>(f)(arg, arg)) {
     return std::forward<F>(f)(arg, arg);
   }
+
+  template <class Hash>
+  static constexpr auto get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 absl::BitGen& GlobalBitGen() {
diff --git a/absl/container/internal/raw_hash_set_test.cc b/absl/container/internal/raw_hash_set_test.cc
index 6baaa06..7ec72b2 100644
--- a/absl/container/internal/raw_hash_set_test.cc
+++ b/absl/container/internal/raw_hash_set_test.cc
@@ -405,6 +405,11 @@
     return absl::container_internal::DecomposeValue(
         std::forward<F>(f), std::forward<Args>(args)...);
   }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 using IntPolicy = ValuePolicy<int64_t>;
@@ -468,6 +473,11 @@
     return apply_impl(std::forward<F>(f),
                       PairArgs(std::forward<Args>(args)...));
   }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 struct StringHash : absl::Hash<absl::string_view> {
@@ -923,6 +933,11 @@
   static auto apply(F&& f, const T& x) -> decltype(std::forward<F>(f)(x, x)) {
     return std::forward<F>(f)(x, x);
   }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return nullptr;
+  }
 };
 
 template <typename Hash, typename Eq>
@@ -1083,7 +1098,7 @@
 }
 
 struct Modulo1000Hash {
-  size_t operator()(int x) const { return x % 1000; }
+  size_t operator()(int64_t x) const { return static_cast<size_t>(x) % 1000; }
 };
 
 struct Modulo1000HashTable
@@ -2664,7 +2679,7 @@
 TEST(Table, AllocatorPropagation) { TestAllocPropagation<RawHashSetAlloc>(); }
 
 struct CountedHash {
-  size_t operator()(int value) const {
+  size_t operator()(int64_t value) const {
     ++count;
     return static_cast<size_t>(value);
   }
diff --git a/absl/container/node_hash_map.h b/absl/container/node_hash_map.h
index a396de2..72e7895 100644
--- a/absl/container/node_hash_map.h
+++ b/absl/container/node_hash_map.h
@@ -36,6 +36,7 @@
 #ifndef ABSL_CONTAINER_NODE_HASH_MAP_H_
 #define ABSL_CONTAINER_NODE_HASH_MAP_H_
 
+#include <cstddef>
 #include <tuple>
 #include <type_traits>
 #include <utility>
@@ -590,6 +591,13 @@
 
   static Value& value(value_type* elem) { return elem->second; }
   static const Value& value(const value_type* elem) { return elem->second; }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return memory_internal::IsLayoutCompatible<Key, Value>::value
+               ? &TypeErasedDerefAndApplyToSlotFn<Hash, Key>
+               : nullptr;
+  }
 };
 }  // namespace container_internal
 
diff --git a/absl/container/node_hash_set.h b/absl/container/node_hash_set.h
index 421ff46..ec9ab51 100644
--- a/absl/container/node_hash_set.h
+++ b/absl/container/node_hash_set.h
@@ -35,10 +35,12 @@
 #ifndef ABSL_CONTAINER_NODE_HASH_SET_H_
 #define ABSL_CONTAINER_NODE_HASH_SET_H_
 
+#include <cstddef>
 #include <type_traits>
 
 #include "absl/algorithm/container.h"
 #include "absl/base/macros.h"
+#include "absl/container/internal/container_memory.h"
 #include "absl/container/internal/hash_function_defaults.h"  // IWYU pragma: export
 #include "absl/container/internal/node_slot_policy.h"
 #include "absl/container/internal/raw_hash_set.h"  // IWYU pragma: export
@@ -487,6 +489,11 @@
   }
 
   static size_t element_space_used(const T*) { return sizeof(T); }
+
+  template <class Hash>
+  static constexpr HashSlotFn get_hash_slot_fn() {
+    return &TypeErasedDerefAndApplyToSlotFn<Hash, T>;
+  }
 };
 }  // namespace container_internal