blob: 2447598460b9fe861209b828c99eba9336962243 [file] [log] [blame]
// Copyright 2025 The BoringSSL Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "merkle_tree.h"
#include <cassert>
#include <cstdint>
#include <cstring>
#include <initializer_list>
#include <limits>
#include <optional>
#include <vector>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <openssl/sha2.h>
BSSL_NAMESPACE_BEGIN
namespace {
class MerkleTree {
 public:
  // Data supplies the leaves of a MerkleTree and a cache for the hashes of its
  // interior nodes.
  class Data {
   public:
    virtual ~Data() = default;
    // At returns the contents of the leaf at `index`.
    virtual std::vector<uint8_t> At(uint64_t index) = 0;
    // Caching functions are only called for full subtrees (size is a power of
    // 2) that have size at least 2. NodeHash returns nullopt on a cache miss,
    // in which case MTH computes the hash and passes it to CacheNodeHash.
    virtual std::optional<TreeHash> NodeHash(Subtree node) = 0;
    virtual void CacheNodeHash(Subtree node, TreeHash hash) = 0;
  };

  // Constructs a MerkleTree over `data`. `data` is not owned and must outlive
  // this object.
  explicit MerkleTree(Data *data) : data_(data) {}
  MerkleTree(const MerkleTree &) = delete;
  MerkleTree(MerkleTree &&) = default;
  MerkleTree &operator=(const MerkleTree &) = delete;
  MerkleTree &operator=(MerkleTree &&) = default;

  // MTH computes the Merkle Tree Hash (MTH; RFC 9162, section 2.1.1) of the
  // elements of D_n covered by `subtree`, i.e. d[subtree.start] through
  // d[subtree.end - 1]. It returns nullopt if `subtree` is not valid (see
  // `Subtree::IsValid`) or if there is an internal error.
  //
  // Note that the MTH function defined in RFC 9162 takes an ordered list of
  // inputs D_n. This function instead identifies the inputs by the start and
  // end indices of `subtree` and reads each leaf from the `Data` source.
  std::optional<TreeHash> MTH(const Subtree &subtree) {
    if (!subtree.IsValid()) {
      return std::nullopt;
    }
    SHA256_CTX ctx;
    SHA256_Init(&ctx);
    TreeHash out;
    const uint64_t n = subtree.Size();
    if (n == 0) {
      // The hash of an empty list is the hash of an empty string.
      SHA256_Final(out.data(), &ctx);
      return out;
    }
    if (n == 1) {
      // One element in the list: return a leaf hash, SHA-256(0x00 || leaf).
      static constexpr uint8_t kLeafPrefix = 0x00;
      std::vector<uint8_t> leaf = data_->At(subtree.start);
      SHA256_Update(&ctx, &kLeafPrefix, 1);
      SHA256_Update(&ctx, leaf.data(), leaf.size());
      SHA256_Final(out.data(), &ctx);
      return out;
    }
    // Here n >= 2. Only use the cache for subtrees whose size is a power of 2,
    // matching the contract documented on `Data`.
    const bool use_cache = (n & (n - 1)) == 0;
    if (use_cache) {
      if (auto hash_opt = data_->NodeHash(subtree); hash_opt.has_value()) {
        return *hash_opt;
      }
    }
    // n elements in the list: MTH() is defined recursively over the left and
    // right children of `subtree`.
    auto left = MTH(subtree.Left());
    auto right = MTH(subtree.Right());
    if (!left.has_value() || !right.has_value()) {
      return std::nullopt;
    }
    HashNode(*left, *right, out);
    if (use_cache) {
      data_->CacheNodeHash(subtree, out);
    }
    return out;
  }

  // InclusionProof computes an inclusion proof for the element at `index`
  // within `subtree`. It returns nullopt if `index` does not lie within
  // `subtree` or if either range is invalid.
  std::optional<std::vector<TreeHash>> InclusionProof(uint64_t index,
                                                      const Subtree &subtree) {
    // An inclusion proof is the subtree consistency proof of the
    // single-element subtree [index, index + 1).
    return SubtreeSubproof({index, index + 1}, subtree, true);
  }

  // ConsistencyProof computes a consistency proof that `subtree` is contained
  // in `tree`. It returns nullopt if either input is invalid or if `tree` does
  // not contain `subtree`.
  std::optional<std::vector<TreeHash>> ConsistencyProof(const Subtree &subtree,
                                                        const Subtree &tree) {
    return SubtreeSubproof(subtree, tree, true);
  }

 private:
  // SubtreeSubproof computes a SUBTREE_SUBPROOF from `subtree` to `tree`, where
  // `subtree` is contained within `tree`. That is the list of node hashes
  // which, combined with the hash of `subtree`, recompute the hash of `tree`.
  //
  // When the recursion reaches a node equal to `subtree`, that node's hash is
  // omitted if `known_hash` is true and included otherwise. It returns nullopt
  // if either input is invalid, if `tree` does not contain `subtree`, or if
  // `tree` is empty.
  std::optional<std::vector<TreeHash>> SubtreeSubproof(Subtree subtree,
                                                       const Subtree &tree,
                                                       bool known_hash) {
    if (!subtree.IsValid() || !tree.IsValid() || !tree.Contains(subtree)) {
      // Invalid inputs
      return std::nullopt;
    }
    const uint64_t n = tree.Size();
    if (n == 0) {
      // There must be a tree with contents for there to be a proof that
      // something is in said tree.
      return std::nullopt;
    }
    if (subtree == tree) {
      if (known_hash) {
        return std::vector<TreeHash>();
      }
      auto mth = MTH(tree);
      if (!mth.has_value()) {
        return std::nullopt;
      }
      return {{*mth}};
    }
    const uint64_t k = tree.Split();
    Subtree subproof_tree, mth_tree;
    if (subtree.end <= k) {
      // `subtree` lies entirely in the left child. Recurse there and append
      // the hash of the right child.
      subproof_tree = tree.Left();
      mth_tree = tree.Right();
    } else if (subtree.start >= k) {
      // `subtree` lies entirely in the right child. Recurse there and append
      // the hash of the left child.
      mth_tree = tree.Left();
      subproof_tree = tree.Right();
    } else {
      // `subtree` straddles the split point. Continue with only the portion of
      // `subtree` at or after k, within the right child. That portion's hash is
      // not already known, so it must appear in the proof.
      subtree.start = k;
      mth_tree = tree.Left();
      subproof_tree = tree.Right();
      known_hash = false;
    }
    auto subproof = SubtreeSubproof(subtree, subproof_tree, known_hash);
    auto mth = MTH(mth_tree);
    if (!subproof.has_value() || !mth.has_value()) {
      return std::nullopt;
    }
    subproof->push_back(*mth);
    return subproof;
  }

  Data *data_;
};
// countr_zero returns the number of consecutive zero bits in `n`, counting up
// from the least significant bit. For `n == 0` it returns the bit width of
// `n` (64).
size_t countr_zero(uint64_t n) {
  if (n == 0) {
    return 8 * sizeof(n);
  }
  size_t zeros = 0;
  for (uint64_t bit = 1; (n & bit) == 0; bit <<= 1) {
    zeros++;
  }
  return zeros;
}
// ConcatData is a MerkleTree::Data whose leaf at index i is `label` followed by
// the in-memory bytes of i. Cached node hashes are stored in a table indexed
// first by level (level 0 holds size-2 subtrees) and then by the node's
// position within that level.
class ConcatData : public MerkleTree::Data {
 public:
  explicit ConcatData(Span<const uint8_t> label)
      : label_(label.begin(), label.end()) {}

  std::vector<uint8_t> At(uint64_t index) override {
    std::vector<uint8_t> leaf(label_);
    leaf.resize(label_.size() + sizeof(index));
    memcpy(leaf.data() + label_.size(), &index, sizeof(index));
    return leaf;
  }

  std::optional<TreeHash> NodeHash(Subtree node) override {
    const size_t level = countr_zero(node.Size()) - 1;
    if (level >= node_cache_.size()) {
      return std::nullopt;
    }
    const auto &row = node_cache_[level];
    const size_t position = node.start / node.Size();
    if (position >= row.size()) {
      return std::nullopt;
    }
    return row[position];
  }

  void CacheNodeHash(Subtree node, TreeHash hash) override {
    const size_t level = countr_zero(node.Size()) - 1;
    const size_t position = node.start / node.Size();
    if (level >= node_cache_.size()) {
      // MTH hashes (and caches) children before their parent, so at most one
      // new level is ever needed.
      BSSL_CHECK(level == node_cache_.size());
      node_cache_.emplace_back();
    }
    auto &row = node_cache_[level];
    if (position >= row.size()) {
      // Any gap before `position` is left as empty (nullopt) entries.
      row.resize(position + 1);
    }
    row[position] = hash;
  }

 private:
  std::vector<uint8_t> label_;
  std::vector<std::vector<std::optional<TreeHash>>> node_cache_;
};
TEST(MerkleTreeTest, SubtreeIsValid) {
  constexpr uint64_t kMax = std::numeric_limits<uint64_t>::max();
  // The empty subtree is valid.
  EXPECT_TRUE((Subtree{0, 0}.IsValid()));
  // A subtree whose end precedes its start is not.
  EXPECT_FALSE((Subtree{1, 0}.IsValid()));
  // The largest representable subtree is valid.
  EXPECT_TRUE((Subtree{0, kMax}.IsValid()));
  // A subtree need not start at zero...
  EXPECT_TRUE((Subtree{4, 8}.IsValid()));
  // ...but a nonzero start limits how large the subtree may be.
  EXPECT_FALSE((Subtree{4, 9}.IsValid()));
  // The right edge may be ragged.
  EXPECT_TRUE((Subtree{4, 6}.IsValid()));
  EXPECT_TRUE((Subtree{0, 6}.IsValid()));
}
TEST(MerkleTreeTest, SubtreeSplit) {
  constexpr uint64_t kMax = std::numeric_limits<uint64_t>::max();
  // An empty subtree splits at its start, which is also its end.
  EXPECT_EQ((Subtree{24601, 24601}).Split(), uint64_t{24601});
  // A single-item subtree splits at its end.
  EXPECT_EQ((Subtree{1336, 1337}).Split(), uint64_t{1337});
  // A two-item subtree splits between its items.
  EXPECT_EQ((Subtree{42, 44}).Split(), uint64_t{43});
  // Size one less than a power of 2.
  EXPECT_EQ((Subtree{0, 31}).Split(), uint64_t{16});
  // Size exactly a power of 2.
  EXPECT_EQ((Subtree{64, 128}).Split(), uint64_t{96});
  // Size one more than a power of 2.
  EXPECT_EQ((Subtree{0, 257}).Split(), uint64_t{256});
  // The largest representable tree.
  EXPECT_EQ((Subtree{0, kMax}).Split(), uint64_t{1} << 63);
  // A small tree whose end is the largest representable index.
  EXPECT_EQ((Subtree{kMax - 3, kMax}).Split(), kMax - 1);
}
std::vector<uint8_t> ConcatProof(const std::vector<TreeHash> &proof) {
std::vector<uint8_t> out;
for (const auto &p : proof) {
out.insert(out.end(), p.begin(), p.end());
}
return out;
}
// An inclusion proof is the subtree consistency proof of a single-element
// subtree. Check that one verifies with EvaluateMerkleSubtreeConsistencyProof.
TEST(MerkleTreeTest, VerifySubtreeConsistencyProof) {
  ConcatData tree_data(StringAsBytes("label"));
  MerkleTree tree(&tree_data);
  const uint64_t index = 0;
  const Subtree leaf{index, index + 1};
  const Subtree subtree{0, 16};
  // Check the leaf hash explicitly rather than dereferencing an empty optional
  // below, which would be undefined behavior instead of a test failure.
  auto leaf_hash = tree.MTH(leaf);
  ASSERT_TRUE(leaf_hash.has_value());
  auto proof = tree.InclusionProof(index, subtree);
  ASSERT_TRUE(proof.has_value());
  auto root_hash = EvaluateMerkleSubtreeConsistencyProof(
      subtree.end, leaf, ConcatProof(*proof), *leaf_hash);
  ASSERT_TRUE(root_hash.has_value());
  EXPECT_EQ(root_hash, tree.MTH(subtree));
}
// Test that the computed inclusion and consistency proofs match the examples
// given in RFC 9162, section 2.1.5.
TEST(MerkleTreeTest, SubtreeConsistencyProofRFC9162) {
  ConcatData tree_data(StringAsBytes("label"));
  MerkleTree tree(&tree_data);
  // The example from section 2.1.5 has a final tree with 7 leaves.
  Subtree final_tree{0, 7};
  // The examples refer to letters representing the MTH of various subtrees
  // within that tree.
  // a isn't used in any of the examples.
  auto b = tree.MTH({1, 2});
  auto c = tree.MTH({2, 3});
  auto d = tree.MTH({3, 4});
  // e isn't used in any of the examples.
  auto f = tree.MTH({5, 6});
  auto g = tree.MTH({0, 2});
  auto h = tree.MTH({2, 4});
  auto i = tree.MTH({4, 6});
  auto j = tree.MTH({6, 7});
  auto k = tree.MTH({0, 4});
  auto l = tree.MTH({4, 7});
  // Fail cleanly, rather than dereferencing an empty optional below, if any of
  // the hashes could not be computed.
  for (const std::optional<TreeHash> *hash :
       {&b, &c, &d, &f, &g, &h, &i, &j, &k, &l}) {
    ASSERT_TRUE(hash->has_value());
  }

  // Inclusion proofs:
  // Section 2.1.5: "The inclusion proof for `d0` is `[b, h, l]`."
  auto d0_proof = tree.InclusionProof(0, final_tree);
  ASSERT_TRUE(d0_proof.has_value());
  EXPECT_THAT(*d0_proof, testing::ElementsAre(*b, *h, *l));
  // Section 2.1.5: "The inclusion proof for `d3` is `[c, g, l]`."
  auto d3_proof = tree.InclusionProof(3, final_tree);
  ASSERT_TRUE(d3_proof.has_value());
  EXPECT_THAT(*d3_proof, testing::ElementsAre(*c, *g, *l));
  // Section 2.1.5: "The inclusion proof for `d4` is `[f, j, k]`."
  auto d4_proof = tree.InclusionProof(4, final_tree);
  ASSERT_TRUE(d4_proof.has_value());
  EXPECT_THAT(*d4_proof, testing::ElementsAre(*f, *j, *k));
  // Section 2.1.5: "The inclusion proof for `d6` is `[i, k]`."
  auto d6_proof = tree.InclusionProof(6, final_tree);
  ASSERT_TRUE(d6_proof.has_value());
  EXPECT_THAT(*d6_proof, testing::ElementsAre(*i, *k));

  // Consistency proofs:
  // The consistency proofs refer to the lettered MTHs above, as well as some
  // MTHs representing the tree as it was incrementally built.
  Subtree hash0_subtree = {0, 3};
  Subtree hash1_subtree = {0, 4};
  auto hash1 = tree.MTH(hash1_subtree);
  ASSERT_EQ(hash1, k);
  Subtree hash2_subtree = {0, 6};
  // "The consistency proof between hash0 and hash is [c, d, g, l]."
  auto hash0_proof = tree.ConsistencyProof(hash0_subtree, final_tree);
  ASSERT_TRUE(hash0_proof.has_value());
  EXPECT_THAT(*hash0_proof, testing::ElementsAre(*c, *d, *g, *l));
  // "The consistency proof between hash1 and hash is [l]."
  auto hash1_proof = tree.ConsistencyProof(hash1_subtree, final_tree);
  ASSERT_TRUE(hash1_proof.has_value());
  EXPECT_THAT(*hash1_proof, testing::ElementsAre(*l));
  // "The consistency proof between hash2 and hash is [i, j, k]."
  auto hash2_proof = tree.ConsistencyProof(hash2_subtree, final_tree);
  ASSERT_TRUE(hash2_proof.has_value());
  EXPECT_THAT(*hash2_proof, testing::ElementsAre(*i, *j, *k));
}
// A single fixed case of ValidProofs: the consistency proof from the first
// three leaves of a four-leaf tree verifies against the full tree's hash.
TEST(MerkleTreeTest, ValidProofsTest) {
  ConcatData tree_data(StringAsBytes("label"));
  MerkleTree tree(&tree_data);
  const uint64_t n = 4;
  const uint64_t start = 0;
  const uint64_t end = 3;
  Subtree full_tree{0, n};
  // Require both hashes up front. Otherwise a nullopt tree hash could compare
  // equal to a nullopt computed hash, and a nullopt subtree hash would be
  // dereferenced.
  auto tree_hash = tree.MTH(full_tree);
  ASSERT_TRUE(tree_hash.has_value());
  Subtree subtree{start, end};
  ASSERT_TRUE(subtree.IsValid());
  auto subtree_hash = tree.MTH(subtree);
  ASSERT_TRUE(subtree_hash.has_value());
  auto proof = tree.ConsistencyProof(subtree, full_tree);
  ASSERT_TRUE(proof.has_value());
  auto computed_hash = EvaluateMerkleSubtreeConsistencyProof(
      n, subtree, ConcatProof(*proof), *subtree_hash);
  EXPECT_EQ(computed_hash, tree_hash);
}
// For every valid subtree of every tree of size below `limit`, the consistency
// proof to the full tree verifies against the full tree's hash.
TEST(MerkleTreeTest, ValidProofs) {
  ConcatData tree_data(StringAsBytes("label"));
  MerkleTree tree(&tree_data);
  // As of the time of writing this test, a run was performed with limit=257 and
  // the test passed (but it took 1.7 seconds to run). This value is set to 129
  // to balance how much of the space to explore with test execution time.
  const uint64_t limit = 129;
  for (uint64_t n = 0; n < limit; n++) {
    Subtree full_tree{0, n};
    // Without this check, a nullopt tree hash would compare equal to a nullopt
    // computed hash and mask a failure.
    auto tree_hash = tree.MTH(full_tree);
    ASSERT_TRUE(tree_hash.has_value());
    for (uint64_t end = 0; end <= n; end++) {
      for (uint64_t start = 0; start < end; start++) {
        Subtree subtree{start, end};
        // Only ranges that form valid subtrees have proofs.
        if (!subtree.IsValid()) {
          continue;
        }
        SCOPED_TRACE(testing::Message() << "Tree n=" << n << ", start: "
                                        << start << ", end: " << end);
        auto subtree_hash = tree.MTH(subtree);
        ASSERT_TRUE(subtree_hash.has_value());
        auto proof = tree.ConsistencyProof(subtree, full_tree);
        ASSERT_TRUE(proof.has_value());
        auto computed_hash = EvaluateMerkleSubtreeConsistencyProof(
            n, subtree, ConcatProof(*proof), *subtree_hash);
        EXPECT_EQ(computed_hash, tree_hash);
      }
    }
  }
}
// ConstData is a MerkleTree::Data in which every leaf is the empty string.
// Because all leaves are identical, every full subtree of a given size has the
// same hash, so the cache keeps one hash per level (level 0 holds size-2
// subtrees). This makes hashing trees with up to 2^64 - 1 leaves feasible.
class ConstData : public MerkleTree::Data {
 public:
  std::vector<uint8_t> At(uint64_t /*index*/) override { return {}; }

  std::optional<TreeHash> NodeHash(Subtree node) override {
    const size_t level = countr_zero(node.Size()) - 1;
    if (level < levels_.size()) {
      return levels_[level];
    }
    return std::nullopt;
  }

  void CacheNodeHash(Subtree node, TreeHash hash) override {
    // MTH only caches after a NodeHash miss, and hashes children before their
    // parent, so each call must supply exactly the next level up.
    BSSL_CHECK(countr_zero(node.Size()) - 1 == levels_.size());
    levels_.push_back(hash);
  }

 private:
  std::vector<TreeHash> levels_;
};
// Exercises consistency proofs in a tree of the largest representable size,
// 2^64 - 1 leaves, using ConstData's per-level cache so hashing stays cheap.
TEST(MerkleTreeTest, VeryLargeProofs) {
  ConstData tree_data;
  MerkleTree tree(&tree_data);
  const uint64_t u64_max = std::numeric_limits<uint64_t>::max();
  Subtree fullest_tree = {0, u64_max};
  auto root_hash = tree.MTH(fullest_tree);
  ASSERT_TRUE(root_hash.has_value());
  Subtree test_subtrees[] = {
      fullest_tree,
      {0, 1},
      {0, 1ull << 63},
      {1ull << 63, u64_max},
      {u64_max - 1, u64_max},
  };
  for (const Subtree &subtree : test_subtrees) {
    SCOPED_TRACE(testing::Message() << "Subtree start: " << subtree.start
                                    << ", end: " << subtree.end);
    // Check the subtree hash before dereferencing it; an empty optional would
    // otherwise be undefined behavior rather than a test failure.
    auto subtree_hash = tree.MTH(subtree);
    ASSERT_TRUE(subtree_hash.has_value());
    auto proof = tree.ConsistencyProof(subtree, fullest_tree);
    ASSERT_TRUE(proof.has_value());
    auto computed_root_hash = EvaluateMerkleSubtreeConsistencyProof(
        fullest_tree.end, subtree, ConcatProof(*proof), *subtree_hash);
    EXPECT_EQ(computed_root_hash, root_hash);
  }
}
} // namespace
BSSL_NAMESPACE_END