Implement `btree::iterator::+=` and `-=`. PiperOrigin-RevId: 735878054 Change-Id: I37e0c89b66f5e31376e007dda8d4420a6dfe5269
diff --git a/absl/container/btree_benchmark.cc b/absl/container/btree_benchmark.cc index 0d26fd4..d0dac37 100644 --- a/absl/container/btree_benchmark.cc +++ b/absl/container/btree_benchmark.cc
@@ -735,7 +735,7 @@ BIG_TYPE_PTR_BENCHMARKS(32); -void BM_BtreeSet_IteratorSubtraction(benchmark::State& state) { +void BM_BtreeSet_IteratorDifference(benchmark::State& state) { absl::InsecureBitGen bitgen; std::vector<int> vec; // Randomize the set's insertion order so the nodes aren't all full. @@ -756,6 +756,52 @@ } } +BENCHMARK(BM_BtreeSet_IteratorDifference)->Range(1 << 10, 1 << 20); + +void BM_BtreeSet_IteratorAddition(benchmark::State& state) { + absl::InsecureBitGen bitgen; + std::vector<int> vec; + // Randomize the set's insertion order so the nodes aren't all full. + vec.reserve(static_cast<size_t>(state.range(0))); + for (int i = 0; i < state.range(0); ++i) vec.push_back(i); + absl::c_shuffle(vec, bitgen); + + absl::btree_set<int> set; + for (int i : vec) set.insert(i); + + size_t distance = absl::Uniform(bitgen, 0u, set.size()); + while (state.KeepRunningBatch(distance)) { + // Let the increment go all the way to the `end` iterator. + const size_t begin = + absl::Uniform(absl::IntervalClosed, bitgen, 0u, set.size() - distance); + auto it = set.find(static_cast<int>(begin)); + benchmark::DoNotOptimize(it += static_cast<int>(distance)); + distance = absl::Uniform(bitgen, 0u, set.size()); + } +} + +BENCHMARK(BM_BtreeSet_IteratorAddition)->Range(1 << 10, 1 << 20); + +void BM_BtreeSet_IteratorSubtraction(benchmark::State& state) { + absl::InsecureBitGen bitgen; + std::vector<int> vec; + // Randomize the set's insertion order so the nodes aren't all full. + vec.reserve(static_cast<size_t>(state.range(0))); + for (int i = 0; i < state.range(0); ++i) vec.push_back(i); + absl::c_shuffle(vec, bitgen); + + absl::btree_set<int> set; + for (int i : vec) set.insert(i); + + size_t distance = absl::Uniform(bitgen, 0u, set.size()); + while (state.KeepRunningBatch(distance)) { + size_t end = absl::Uniform(bitgen, distance, set.size()); + auto it = set.find(static_cast<int>(end)); + benchmark::DoNotOptimize(it -= static_cast<int>(distance)); + distance = absl::Uniform(bitgen, 0u, set.size()); + } +} + BENCHMARK(BM_BtreeSet_IteratorSubtraction)->Range(1 << 10, 1 << 20); } // namespace
diff --git a/absl/container/btree_map.h b/absl/container/btree_map.h index 470de2a..32a82ef 100644 --- a/absl/container/btree_map.h +++ b/absl/container/btree_map.h
@@ -47,8 +47,10 @@ // iterator at the current position. Another important difference is that // key-types must be copy-constructible. // -// Another API difference is that btree iterators can be subtracted, and this -// is faster than using std::distance. +// There are other API differences: first, btree iterators can be subtracted, +// and this is faster than using `std::distance`. Additionally, btree +// iterators can be advanced via `operator+=` and `operator-=`, which is faster +// than using `std::advance`. // // B-tree maps are not exception-safe.
diff --git a/absl/container/btree_set.h b/absl/container/btree_set.h index e57d6d9..16181de 100644 --- a/absl/container/btree_set.h +++ b/absl/container/btree_set.h
@@ -46,8 +46,10 @@ // reason, `insert()`, `erase()`, and `extract_and_get_next()` return a valid // iterator at the current position. // -// Another API difference is that btree iterators can be subtracted, and this -// is faster than using std::distance. +// There are other API differences: first, btree iterators can be subtracted, +// and this is faster than using `std::distance`. Additionally, btree +// iterators can be advanced via `operator+=` and `operator-=`, which is faster +// than using `std::advance`. // // B-tree sets are not exception-safe.
diff --git a/absl/container/btree_test.cc b/absl/container/btree_test.cc index c398922..1d2c2a6 100644 --- a/absl/container/btree_test.cc +++ b/absl/container/btree_test.cc
@@ -3351,7 +3351,7 @@ set.insert(0); } -TEST(Btree, IteratorSubtraction) { +TEST(Btree, IteratorDifference) { absl::BitGen bitgen; std::vector<int> vec; // Randomize the set's insertion order so the nodes aren't all full. @@ -3369,6 +3369,94 @@ } } +TEST(Btree, IteratorAddition) { + absl::BitGen bitgen; + std::vector<int> vec; + + // Randomize the set's insertion order so the nodes aren't all full. + constexpr int kSetSize = 1000000; + for (int i = 0; i < kSetSize; ++i) vec.push_back(i); + absl::c_shuffle(vec, bitgen); + + absl::btree_set<int> set; + for (int i : vec) set.insert(i); + + for (int i = 0; i < 1000; ++i) { + int begin = absl::Uniform(bitgen, 0, kSetSize); + int end = absl::Uniform(bitgen, begin, kSetSize); + ASSERT_LE(begin, end); + + auto it = set.find(begin); + it += end - begin; + ASSERT_EQ(it, set.find(end)) << end; + + it += begin - end; + ASSERT_EQ(it, set.find(begin)) << begin; + } +} + +TEST(Btree, IteratorAdditionOutOfBounds) { + const absl::btree_set<int> set({5}); + + auto it = set.find(5); + + auto forward = it; + forward += 1; + EXPECT_EQ(forward, set.end()); + + auto backward = it; + EXPECT_EQ(backward, set.begin()); + + if (IsAssertEnabled()) { + EXPECT_DEATH(forward += 1, "n == 0"); + EXPECT_DEATH(backward += -1, "position >= node->start"); + } +} + +TEST(Btree, IteratorSubtraction) { + absl::BitGen bitgen; + std::vector<int> vec; + + // Randomize the set's insertion order so the nodes aren't all full. + constexpr int kSetSize = 1000000; + for (int i = 0; i < kSetSize; ++i) vec.push_back(i); + absl::c_shuffle(vec, bitgen); + + absl::btree_set<int> set; + for (int i : vec) set.insert(i); + + for (int i = 0; i < 1000; ++i) { + int begin = absl::Uniform(bitgen, 0, kSetSize); + int end = absl::Uniform(bitgen, begin, kSetSize); + ASSERT_LE(begin, end); + + auto it = set.find(end); + it -= end - begin; + ASSERT_EQ(it, set.find(begin)) << begin; + + it -= begin - end; + ASSERT_EQ(it, set.find(end)) << end; + } +} + +TEST(Btree, IteratorSubtractionOutOfBounds) { + const absl::btree_set<int> set({5}); + + auto it = set.find(5); + + auto backward = it; + EXPECT_EQ(backward, set.begin()); + + auto forward = it; + forward -= -1; + EXPECT_EQ(forward, set.end()); + + if (IsAssertEnabled()) { + EXPECT_DEATH(backward -= 1, "position >= node->start"); + EXPECT_DEATH(forward -= -1, "n == 0"); + } +} + TEST(Btree, DereferencingEndIterator) { if (!IsAssertEnabled()) GTEST_SKIP() << "Assertions not enabled.";
diff --git a/absl/container/internal/btree.h b/absl/container/internal/btree.h index 689e71a..a742829 100644 --- a/absl/container/internal/btree.h +++ b/absl/container/internal/btree.h
@@ -60,6 +60,7 @@ #include "absl/base/config.h" #include "absl/base/internal/raw_logging.h" #include "absl/base/macros.h" +#include "absl/base/optimization.h" #include "absl/container/internal/common.h" #include "absl/container/internal/common_policy_traits.h" #include "absl/container/internal/compressed_tuple.h" @@ -708,6 +709,8 @@ } // Getter for the parent of this node. + // TODO(ezb): assert that the child of the returned node at position + // `node_->position()` maps to the current node. btree_node *parent() const { return *GetField<0>(); } // Getter for whether the node is the root of the tree. The parent of the // root of the tree is the leftmost node in the tree which is guaranteed to @@ -1175,6 +1178,26 @@ return distance_slow(other); } + // Advances the iterator by `n`. Values of `n` must not result in going past + // the `end` iterator (for a positive `n`) or before the `begin` iterator (for + // a negative `n`). + btree_iterator &operator+=(difference_type n) { + assert_valid_generation(node_); + if (n == 0) return *this; + if (n < 0) return decrement_n_slow(-n); + return increment_n_slow(n); + } + + // Moves the iterator by `n` positions backwards. Values of `n` must not + // result in going before the `begin` iterator (for a positive `n`) or past + // the `end` iterator (for a negative `n`). + btree_iterator &operator-=(difference_type n) { + assert_valid_generation(node_); + if (n == 0) return *this; + if (n < 0) return increment_n_slow(-n); + return decrement_n_slow(n); + } + // Accessors for the key/value the iterator is pointing at. reference operator*() const { ABSL_HARDENING_ASSERT(node_ != nullptr); @@ -1277,6 +1300,7 @@ increment_slow(); } void increment_slow(); + btree_iterator &increment_n_slow(difference_type n); void decrement() { assert_valid_generation(node_); @@ -1286,6 +1310,7 @@ decrement_slow(); } void decrement_slow(); + btree_iterator &decrement_n_slow(difference_type n); const key_type &key() const { return node_->key(static_cast<size_type>(position_)); @@ -2172,6 +2197,80 @@ } } +template <typename N, typename R, typename P> +btree_iterator<N, R, P> &btree_iterator<N, R, P>::increment_n_slow( + difference_type n) { + N *node = node_; + int position = position_; + ABSL_ASSUME(n > 0); + while (n > 0) { + if (node->is_leaf()) { + if (position + n < node->finish()) { + position += n; + break; + } else { + n -= node->finish() - position; + position = node->finish(); + btree_iterator save = {node, position}; + while (position == node->finish() && !node->is_root()) { + position = node->position(); + node = node->parent(); + } + if (position == node->finish()) { + ABSL_HARDENING_ASSERT(n == 0); + return *this = save; + } + } + } else { + --n; + assert(position < node->finish()); + node = node->child(static_cast<field_type>(position + 1)); + while (node->is_internal()) { + node = node->start_child(); + } + position = node->start(); + } + } + node_ = node; + position_ = position; + return *this; +} + +template <typename N, typename R, typename P> +btree_iterator<N, R, P> &btree_iterator<N, R, P>::decrement_n_slow( + difference_type n) { + N *node = node_; + int position = position_; + ABSL_ASSUME(n > 0); + while (n > 0) { + if (node->is_leaf()) { + if (position - n >= node->start()) { + position -= n; + break; + } else { + n -= 1 + position - node->start(); + position = node->start() - 1; + while (position < node->start() && !node->is_root()) { + position = node->position() - 1; + node = node->parent(); + } + ABSL_HARDENING_ASSERT(position >= node->start()); + } + } else { + --n; + assert(position >= node->start()); + node = node->child(static_cast<field_type>(position)); + while (node->is_internal()) { + node = node->child(node->finish()); + } + position = node->finish() - 1; + } + } + node_ = node; + position_ = position; + return *this; +} + //// // btree methods template <typename P>