Refactor the object cache to better account for race conditions (#13204)

Supersedes #13075

Closes #13204

COPYBARA_INTEGRATE_REVIEW=https://github.com/protocolbuffers/protobuf/pull/13204 from fowles:main 9b6aad673d96f92a3e99d1e48f9e42fd43f6791c
PiperOrigin-RevId: 546579459
diff --git a/.gitignore b/.gitignore
index df64e36..f25b4f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -58,6 +58,7 @@
 
 # vim generated
 *.swp
+*~
 
 # Generated test scaffolding
 src/no_warning_test.cc
diff --git a/ruby/ext/google/protobuf_c/defs.c b/ruby/ext/google/protobuf_c/defs.c
index 1035942..f5fe225 100644
--- a/ruby/ext/google/protobuf_c/defs.c
+++ b/ruby/ext/google/protobuf_c/defs.c
@@ -129,9 +129,7 @@
 
   RB_OBJ_WRITE(ret, &self->def_to_descriptor, rb_hash_new());
   self->symtab = upb_DefPool_New();
-  ObjectCache_Add(self->symtab, ret);
-
-  return ret;
+  return ObjectCache_TryAdd(self->symtab, ret);
 }
 
 /*
diff --git a/ruby/ext/google/protobuf_c/map.c b/ruby/ext/google/protobuf_c/map.c
index 18617ed..1c9fa20 100644
--- a/ruby/ext/google/protobuf_c/map.c
+++ b/ruby/ext/google/protobuf_c/map.c
@@ -93,7 +93,6 @@
   if (val == Qnil) {
     val = Map_alloc(cMap);
     Map* self;
-    ObjectCache_Add(map, val);
     TypedData_Get_Struct(val, Map, &Map_type, self);
     self->map = map;
     self->arena = arena;
@@ -103,6 +102,7 @@
       const upb_MessageDef* val_m = self->value_type_info.def.msgdef;
       self->value_type_class = Descriptor_DefToClass(val_m);
     }
+    return ObjectCache_TryAdd(map, val);
   }
 
   return val;
@@ -319,7 +319,9 @@
 
   self->map = upb_Map_New(Arena_get(self->arena), self->key_type,
                           self->value_type_info.type);
-  ObjectCache_Add(self->map, _self);
+  VALUE stored = ObjectCache_TryAdd(self->map, _self);
+  (void)stored;
+  PBRUBY_ASSERT(stored == _self);
 
   if (init_arg != Qnil) {
     Map_merge_into_self(_self, init_arg);
diff --git a/ruby/ext/google/protobuf_c/message.c b/ruby/ext/google/protobuf_c/message.c
index 8208cb6..4d5a488 100644
--- a/ruby/ext/google/protobuf_c/message.c
+++ b/ruby/ext/google/protobuf_c/message.c
@@ -108,7 +108,9 @@
   Message* self = ruby_to_Message(self_);
   self->msg = msg;
   RB_OBJ_WRITE(self_, &self->arena, arena);
-  ObjectCache_Add(msg, self_);
+  VALUE stored = ObjectCache_TryAdd(msg, self_);
+  (void)stored;
+  PBRUBY_ASSERT(stored == self_);
 }
 
 VALUE Message_GetArena(VALUE msg_rb) {
diff --git a/ruby/ext/google/protobuf_c/protobuf.c b/ruby/ext/google/protobuf_c/protobuf.c
index bee8ac2..46f6436 100644
--- a/ruby/ext/google/protobuf_c/protobuf.c
+++ b/ruby/ext/google/protobuf_c/protobuf.c
@@ -195,7 +195,8 @@
     .flags = RUBY_TYPED_FREE_IMMEDIATELY | RUBY_TYPED_WB_PROTECTED,
 };
 
-static void* ruby_upb_allocfunc(upb_alloc* alloc, void* ptr, size_t oldsize, size_t size) {
+static void *ruby_upb_allocfunc(upb_alloc *alloc, void *ptr, size_t oldsize,
+                                size_t size) {
   if (size == 0) {
     xfree(ptr);
     return NULL;
@@ -252,164 +253,40 @@
 // Object Cache
 // -----------------------------------------------------------------------------
 
-// A pointer -> Ruby Object cache that keeps references to Ruby wrapper
-// objects.  This allows us to look up any Ruby wrapper object by the address
-// of the object it is wrapping. That way we can avoid ever creating two
-// different wrapper objects for the same C object, which saves memory and
-// preserves object identity.
-//
-// We use WeakMap for the cache. For Ruby <2.7 we also need a secondary Hash
-// to store WeakMap keys because Ruby <2.7 WeakMap doesn't allow non-finalizable
-// keys.
-//
-// We also need the secondary Hash if sizeof(long) < sizeof(VALUE), because this
-// means it may not be possible to fit a pointer into a Fixnum. Keys are
-// pointers, and if they fit into a Fixnum, Ruby doesn't collect them, but if
-// they overflow and require allocating a Bignum, they could get collected
-// prematurely, thus removing the cache entry. This happens on 64-bit Windows,
-// on which pointers are 64 bits but longs are 32 bits. In this case, we enable
-// the secondary Hash to hold the keys and prevent them from being collected.
-
-#if RUBY_API_VERSION_CODE >= 20700 && SIZEOF_LONG >= SIZEOF_VALUE
-#define USE_SECONDARY_MAP 0
-#else
-#define USE_SECONDARY_MAP 1
-#endif
-
-#if USE_SECONDARY_MAP
-
-// Maps Numeric -> Object. The object is then used as a key into the WeakMap.
-// This is needed for Ruby <2.7 where a number cannot be a key to WeakMap.
-// The object is used only for its identity; it does not contain any data.
-VALUE secondary_map = Qnil;
-
-// Mutations to the map are under a mutex, because SeconaryMap_MaybeGC()
-// iterates over the map which cannot happen in parallel with insertions, or
-// Ruby will throw:
-//   can't add a new key into hash during iteration (RuntimeError)
-VALUE secondary_map_mutex = Qnil;
-
-// Lambda that will GC entries from the secondary map that are no longer present
-// in the primary map.
-VALUE gc_secondary_map_lambda = Qnil;
-ID length;
-
-extern VALUE weak_obj_cache;
-
-static void SecondaryMap_Init() {
-  rb_gc_register_address(&secondary_map);
-  rb_gc_register_address(&gc_secondary_map_lambda);
-  rb_gc_register_address(&secondary_map_mutex);
-  secondary_map = rb_hash_new();
-  gc_secondary_map_lambda = rb_eval_string(
-      "->(secondary, weak) {\n"
-      "  secondary.delete_if { |k, v| !weak.key?(v) }\n"
-      "}\n");
-  secondary_map_mutex = rb_mutex_new();
-  length = rb_intern("length");
-}
-
-// The secondary map is a regular Hash, and will never shrink on its own.
-// The main object cache is a WeakMap that will automatically remove entries
-// when the target object is no longer reachable, but unless we manually
-// remove the corresponding entries from the secondary map, it will grow
-// without bound.
-//
-// To avoid this unbounded growth we periodically remove entries from the
-// secondary map that are no longer present in the WeakMap. The logic of
-// how often to perform this GC is an artbirary tuning parameter that
-// represents a straightforward CPU/memory tradeoff.
-//
-// Requires: secondary_map_mutex is held.
-static void SecondaryMap_MaybeGC() {
-  PBRUBY_ASSERT(rb_mutex_locked_p(secondary_map_mutex) == Qtrue);
-  size_t weak_len = NUM2ULL(rb_funcall(weak_obj_cache, length, 0));
-  size_t secondary_len = RHASH_SIZE(secondary_map);
-  if (secondary_len < weak_len) {
-    // Logically this case should not be possible: a valid entry cannot exist in
-    // the weak table unless there is a corresponding entry in the secondary
-    // table. It should *always* be the case that secondary_len >= weak_len.
-    //
-    // However ObjectSpace::WeakMap#length (and therefore weak_len) is
-    // unreliable: it overreports its true length by including non-live objects.
-    // However these non-live objects are not yielded in iteration, so we may
-    // have previously deleted them from the secondary map in a previous
-    // invocation of SecondaryMap_MaybeGC().
-    //
-    // In this case, we can't measure any waste, so we just return.
-    return;
-  }
-  size_t waste = secondary_len - weak_len;
-  // GC if we could remove at least 2000 entries or 20% of the table size
-  // (whichever is greater).  Since the cost of the GC pass is O(N), we
-  // want to make sure that we condition this on overall table size, to
-  // avoid O(N^2) CPU costs.
-  size_t threshold = PBRUBY_MAX(secondary_len * 0.2, 2000);
-  if (waste > threshold) {
-    rb_funcall(gc_secondary_map_lambda, rb_intern("call"), 2, secondary_map,
-               weak_obj_cache);
-  }
-}
-
-// Requires: secondary_map_mutex is held by this thread iff create == true.
-static VALUE SecondaryMap_Get(VALUE key, bool create) {
-  PBRUBY_ASSERT(!create || rb_mutex_locked_p(secondary_map_mutex) == Qtrue);
-  VALUE ret = rb_hash_lookup(secondary_map, key);
-  if (ret == Qnil && create) {
-    SecondaryMap_MaybeGC();
-    ret = rb_class_new_instance(0, NULL, rb_cObject);
-    rb_hash_aset(secondary_map, key, ret);
-  }
-  return ret;
-}
-
-#endif
-
-// Requires: secondary_map_mutex is held by this thread iff create == true.
-static VALUE ObjectCache_GetKey(const void *key, bool create) {
-  VALUE key_val = (VALUE)key;
-  PBRUBY_ASSERT((key_val & 3) == 0);
-  VALUE ret = LL2NUM(key_val >> 2);
-#if USE_SECONDARY_MAP
-  ret = SecondaryMap_Get(ret, create);
-#endif
-  return ret;
-}
-
 // Public ObjectCache API.
 
 VALUE weak_obj_cache = Qnil;
 ID item_get;
-ID item_set;
+ID item_try_add;
 
-static void ObjectCache_Init() {
+static void ObjectCache_Init(VALUE protobuf) {
+  item_get = rb_intern("get");
+  item_try_add = rb_intern("try_add");
+  // Register the global as a GC root so the cache instance is never collected.
   rb_gc_register_address(&weak_obj_cache);
-  VALUE klass = rb_eval_string("ObjectSpace::WeakMap");
-  weak_obj_cache = rb_class_new_instance(0, NULL, klass);
-  item_get = rb_intern("[]");
-  item_set = rb_intern("[]=");
-#if USE_SECONDARY_MAP
-  SecondaryMap_Init();
+#if RUBY_API_VERSION_CODE >= 20700 && SIZEOF_LONG >= SIZEOF_VALUE
+  VALUE cache_class = rb_const_get(protobuf, rb_intern("ObjectCache"));
+#else
+  VALUE cache_class = rb_const_get(protobuf, rb_intern("LegacyObjectCache"));
 #endif
+  // Instantiate it, then publish the cache and C type sizes as constants.
+  weak_obj_cache = rb_class_new_instance(0, NULL, cache_class);
+  rb_const_set(protobuf, rb_intern("OBJECT_CACHE"), weak_obj_cache);
+  rb_const_set(protobuf, rb_intern("SIZEOF_LONG"), INT2NUM(SIZEOF_LONG));
+  rb_const_set(protobuf, rb_intern("SIZEOF_VALUE"), INT2NUM(SIZEOF_VALUE));
 }
 
-void ObjectCache_Add(const void *key, VALUE val) {
-  PBRUBY_ASSERT(ObjectCache_Get(key) == Qnil);
-#if USE_SECONDARY_MAP
-  rb_mutex_lock(secondary_map_mutex);
-#endif
-  VALUE key_rb = ObjectCache_GetKey(key, true);
-  rb_funcall(weak_obj_cache, item_set, 2, key_rb, val);
-#if USE_SECONDARY_MAP
-  rb_mutex_unlock(secondary_map_mutex);
-#endif
-  PBRUBY_ASSERT(ObjectCache_Get(key) == val);
+VALUE ObjectCache_TryAdd(const void *key, VALUE val) {
+  VALUE key_val = (VALUE)key;
+  PBRUBY_ASSERT((key_val & 3) == 0);  // Low 2 bits are 0: >> 2 is lossless.
+  return rb_funcall(weak_obj_cache, item_try_add, 2, LL2NUM(key_val >> 2), val);
 }
 
 // Returns the cached object for this key, if any. Otherwise returns Qnil.
 VALUE ObjectCache_Get(const void *key) {
-  VALUE key_rb = ObjectCache_GetKey(key, false);
-  return rb_funcall(weak_obj_cache, item_get, 1, key_rb);
+  VALUE key_val = (VALUE)key;
+  PBRUBY_ASSERT((key_val & 3) == 0);  // Same key derivation as TryAdd.
+  return rb_funcall(weak_obj_cache, item_get, 1, LL2NUM(key_val >> 2));
 }
 
 /*
@@ -459,11 +336,10 @@
 // This must be named "Init_protobuf_c" because the Ruby module is named
 // "protobuf_c" -- the VM looks for this symbol in our .so.
 __attribute__((visibility("default"))) void Init_protobuf_c() {
-  ObjectCache_Init();
-
   VALUE google = rb_define_module("Google");
   VALUE protobuf = rb_define_module_under(google, "Protobuf");
 
+  ObjectCache_Init(protobuf);
   Arena_register(protobuf);
   Defs_register(protobuf);
   RepeatedField_register(protobuf);
diff --git a/ruby/ext/google/protobuf_c/protobuf.h b/ruby/ext/google/protobuf_c/protobuf.h
index a51d9a7..a27bdb8 100644
--- a/ruby/ext/google/protobuf_c/protobuf.h
+++ b/ruby/ext/google/protobuf_c/protobuf.h
@@ -84,10 +84,9 @@
 // being collected (though in Ruby <2.7 is it effectively strong, due to
 // implementation limitations).
 
-// Adds an entry to the cache. The "arena" parameter must give the arena that
-// "key" was allocated from.  In Ruby <2.7.0, it will be used to remove the key
-// from the cache when the arena is destroyed.
-void ObjectCache_Add(const void* key, VALUE val);
+// Tries to add a new entry to the cache, returning the newly installed value or
+// the pre-existing entry.
+VALUE ObjectCache_TryAdd(const void* key, VALUE val);
 
 // Returns the cached object for this key, if any. Otherwise returns Qnil.
 VALUE ObjectCache_Get(const void* key);
diff --git a/ruby/ext/google/protobuf_c/repeated_field.c b/ruby/ext/google/protobuf_c/repeated_field.c
index bcd5f45..cf54681 100644
--- a/ruby/ext/google/protobuf_c/repeated_field.c
+++ b/ruby/ext/google/protobuf_c/repeated_field.c
@@ -87,7 +87,6 @@
   if (val == Qnil) {
     val = RepeatedField_alloc(cRepeatedField);
     RepeatedField* self;
-    ObjectCache_Add(array, val);
     TypedData_Get_Struct(val, RepeatedField, &RepeatedField_type, self);
     self->array = array;
     self->arena = arena;
@@ -95,6 +94,7 @@
     if (self->type_info.type == kUpb_CType_Message) {
       self->type_class = Descriptor_DefToClass(type_info.def.msgdef);
     }
+    val = ObjectCache_TryAdd(array, val);
   }
 
   PBRUBY_ASSERT(ruby_to_RepeatedField(val)->type_info.type == type_info.type);
@@ -615,7 +615,9 @@
 
   self->type_info = TypeInfo_FromClass(argc, argv, 0, &self->type_class, &ary);
   self->array = upb_Array_New(arena, self->type_info.type);
-  ObjectCache_Add(self->array, _self);
+  VALUE stored_val = ObjectCache_TryAdd(self->array, _self);
+  (void)stored_val;
+  PBRUBY_ASSERT(stored_val == _self);
 
   if (ary != Qnil) {
     if (!RB_TYPE_P(ary, T_ARRAY)) {
diff --git a/ruby/lib/google/protobuf.rb b/ruby/lib/google/protobuf.rb
index b7a6711..07985f6 100644
--- a/ruby/lib/google/protobuf.rb
+++ b/ruby/lib/google/protobuf.rb
@@ -30,6 +30,7 @@
 
 # require mixins before we hook them into the java & c code
 require 'google/protobuf/message_exts'
+require 'google/protobuf/object_cache'
 
 # We define these before requiring the platform-specific modules.
 # That way the module init can grab references to these.
diff --git a/ruby/lib/google/protobuf/object_cache.rb b/ruby/lib/google/protobuf/object_cache.rb
new file mode 100644
index 0000000..583cabf
--- /dev/null
+++ b/ruby/lib/google/protobuf/object_cache.rb
@@ -0,0 +1,124 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2023 Google Inc.  All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+#     * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#     * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+#     * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+module Google
+  module Protobuf
+    # A pointer -> Ruby Object cache that keeps references to Ruby wrapper
+    # objects.  This allows us to look up any Ruby wrapper object by the address
+    # of the object it is wrapping. That way we can avoid ever creating two
+    # different wrapper objects for the same C object, which saves memory and
+    # preserves object identity.
+    #
+    # We use WeakMap for the cache. For Ruby <2.7 we also need a secondary Hash
+    # to store WeakMap keys because Ruby <2.7 WeakMap doesn't allow non-finalizable
+    # keys.
+    #
+    # We also need the secondary Hash if sizeof(long) < sizeof(VALUE), because this
+    # means it may not be possible to fit a pointer into a Fixnum. Keys are
+    # pointers, and if they fit into a Fixnum, Ruby doesn't collect them, but if
+    # they overflow and require allocating a Bignum, they could get collected
+    # prematurely, thus removing the cache entry. This happens on 64-bit Windows,
+    # on which pointers are 64 bits but longs are 32 bits. In this case, we enable
+    # the secondary Hash to hold the keys and prevent them from being collected.
+    class ObjectCache
+      def initialize
+        @mutex = Mutex.new
+        @map = ObjectSpace::WeakMap.new
+      end
+
+      # Returns the value stored under key, or nil when there is none.
+      def get(key)
+        @map[key]
+      end
+
+      def try_add(key, value)
+        # Unlocked read first; the mutex is only taken to install a value.
+        @map[key] || @mutex.synchronize { @map[key] ||= value }
+      end
+    end
+
+    # ObjectCache variant that keys the WeakMap on placeholder Objects, which
+    # are kept alive by a regular Hash (@secondary_map) from key to placeholder.
+    class LegacyObjectCache
+      def initialize
+        @secondary_map = {}
+        @map = ObjectSpace::WeakMap.new
+        @mutex = Mutex.new
+      end
+
+      def get(key)
+        value = if secondary_key = @secondary_map[key]
+          @map[secondary_key]
+        else
+          @mutex.synchronize do
+            @map[(@secondary_map[key] ||= Object.new)]
+          end
+        end
+
+        # GC if we could remove at least 2000 entries or 20% of the table size
+        # (whichever is greater).  Since the cost of the GC pass is O(N), we
+        # want to make sure that we condition this on overall table size, to
+        # avoid O(N^2) CPU costs.
+        cutoff = [(@secondary_map.size * 0.2).ceil, 2_000].max
+        purge if (@secondary_map.size - @map.size) > cutoff
+
+        value
+      end
+
+      # Returns the value already cached for key, or installs and returns value.
+      def try_add(key, value)
+        secondary_key = @secondary_map[key]
+        old_value = secondary_key && @map[secondary_key]
+        return old_value if old_value
+
+        @mutex.synchronize do
+          # Look the placeholder up again under the lock: a concurrent purge
+          # may have dropped the one read above, and an entry stored under an
+          # orphaned placeholder could never be found by a later lookup.
+          secondary_key = (@secondary_map[key] ||= Object.new)
+          @map[secondary_key] ||= value
+        end
+      end
+
+      private
+
+      # Drops placeholders that are no longer keys in the WeakMap.
+      def purge
+        @mutex.synchronize do
+          @secondary_map.delete_if do |_key, secondary_key|
+            !@map.key?(secondary_key)
+          end
+        end
+        nil
+      end
+    end
+  end
+end
+
diff --git a/ruby/tests/BUILD.bazel b/ruby/tests/BUILD.bazel
index 9ec6634..c072cc7 100644
--- a/ruby/tests/BUILD.bazel
+++ b/ruby/tests/BUILD.bazel
@@ -77,6 +77,16 @@
 )
 
 ruby_test(
+    name = "object_cache_test",
+    srcs = ["object_cache_test.rb"],
+    deps = [
+        "//ruby:protobuf",
+        "//ruby:test_ruby_protos",
+        "@protobuf_bundle//:test-unit",
+    ],
+)
+
+ruby_test(
     name = "repeated_field_test",
     srcs = ["repeated_field_test.rb"],
     deps = [
diff --git a/ruby/tests/object_cache_test.rb b/ruby/tests/object_cache_test.rb
new file mode 100644
index 0000000..d4e6b49
--- /dev/null
+++ b/ruby/tests/object_cache_test.rb
@@ -0,0 +1,70 @@
+require 'google/protobuf'
+require 'test/unit'
+
+# Checks that the C extension instantiated the cache class for this platform.
+class PlatformTest < Test::Unit::TestCase
+  def test_correct_implementation_for_platform
+    return if RUBY_PLATFORM == "java"
+
+    # Mirrors the preprocessor condition ObjectCache_Init uses to pick a class.
+    pb = Google::Protobuf
+    modern_ruby = Gem::Version.new(RUBY_VERSION) >= Gem::Version.new('2.7.0')
+    fixnum_keys = pb::SIZEOF_LONG >= pb::SIZEOF_VALUE
+    expected = modern_ruby && fixnum_keys ? pb::ObjectCache : pb::LegacyObjectCache
+    assert_instance_of expected, pb::OBJECT_CACHE
+  end
+end
+
+# Shared test cases; the including class supplies #create, which returns a
+# fresh cache instance.
+module ObjectCacheTestModule
+  def test_try_add_returns_existing_value
+    cache = create
+
+    keys = %w[k1 k2]
+    vals = %w[v1 v2]
+    keys.zip(vals, vals.reverse).each do |key, val, other|
+      assert_same val, cache.try_add(key, val)
+      assert_same val, cache.get(key)
+      # A second try_add with a different value must not replace the first.
+      assert_same val, cache.try_add(key, other)
+      assert_same val, cache.get(key)
+    end
+  end
+
+  def test_multithreaded_access
+    cache = create
+    keys = (0..7).map { |n| "k#{n}" }
+
+    # 100 threads each try to install a distinct value under one of 8 keys.
+    threads = Array.new(100) do |i|
+      Thread.new { cache.try_add(keys[i % keys.size], i.to_s) }
+    end
+
+    results = threads.map(&:value)
+    assert_equal 100, results.size
+    # Values are all distinct, so 8 unique results means one winner per key.
+    assert_equal 8, results.uniq.size
+  end
+end
+
+if RUBY_PLATFORM == "java"
+  return
+end
+
+class ObjectCacheTest < Test::Unit::TestCase
+  include ObjectCacheTestModule
+
+  def create
+    Google::Protobuf::ObjectCache.new
+  end
+end
+
+class LegacyObjectCacheTest < Test::Unit::TestCase
+  include ObjectCacheTestModule
+
+  def create
+    Google::Protobuf::LegacyObjectCache.new
+  end
+end
+