diff --git a/category/mpt/test/db_test.cpp b/category/mpt/test/db_test.cpp index c24b081c2a..ea4433aabe 100644 --- a/category/mpt/test/db_test.cpp +++ b/category/mpt/test/db_test.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include #include @@ -1658,6 +1659,101 @@ TYPED_TEST(DbTraverseTest, trimmed_traverse) } } +TEST(RangedGetTest, path_exceeds_min_prefix) +{ + StateMachineAlwaysVarLen machine; + Db db{machine}; + Node::SharedPtr root; + + Nibbles const min{0x0124_bytes}; + Nibbles const max{0x1234_bytes}; + + // in range: longer than min + Nibbles const k_in_long{0x10000000_bytes}; + Nibbles const k_in_long_2{0x02000000_bytes}; + // in range: shorter than min (0x02 > 0x0124) + Nibbles const k_in_short{0x02_bytes}; + // in range: equals min (inclusive) + Nibbles const k_in_at_min{0x0124_bytes}; + // out of range: above max (long) + Nibbles const k_out_long{0x20000000_bytes}; + // out of range: above max (short) + Nibbles const k_out_short{0x13_bytes}; + // out of range: equals max (exclusive) + Nibbles const k_out_at_max{0x1234_bytes}; + auto const val = 0xdeadbeef_bytes; + + uint64_t const block_id = 0; + auto u1 = make_update(k_in_long, val); + auto u2 = make_update(k_in_long_2, val); + auto u3 = make_update(k_in_short, val); + auto u4 = make_update(k_in_at_min, val); + auto u5 = make_update(k_out_long, val); + auto u6 = make_update(k_out_at_max, val); + auto u7 = make_update(k_out_short, val); + UpdateList ul; + ul.push_front(u1); + ul.push_front(u2); + ul.push_front(u3); + ul.push_front(u4); + ul.push_front(u5); + ul.push_front(u6); + ul.push_front(u7); + root = db.upsert(std::move(root), std::move(ul), block_id); + + size_t num_results = 0; + RangedGetMachine range_machine{ + min, + max, + [&num_results](NibblesView const, monad::byte_string_view const) { + ++num_results; + }}; + ASSERT_TRUE(db.traverse(root, range_machine, block_id)); + EXPECT_EQ(num_results, 4); +} + +TEST(RangedGetTest, path_longer_than_min_and_max) +{ + StateMachineAlwaysVarLen machine; + Db db{machine}; + Node::SharedPtr root; + + // Very short bounds: all keys will have paths longer than both min and max. + Nibbles const min{0x03_bytes}; + Nibbles const max{0x05_bytes}; + + Nibbles const k_in_1{0x03000000_bytes}; + Nibbles const k_in_2{0x04ffffff_bytes}; + Nibbles const k_out_below{0x02ffffff_bytes}; + Nibbles const k_out_at_max{0x05000000_bytes}; + Nibbles const k_out_above{0x06000000_bytes}; + auto const val = 0xdeadbeef_bytes; + + uint64_t const block_id = 0; + UpdateList ul; + auto u1 = make_update(k_in_1, val); + auto u2 = make_update(k_in_2, val); + auto u3 = make_update(k_out_below, val); + auto u4 = make_update(k_out_at_max, val); + auto u5 = make_update(k_out_above, val); + ul.push_front(u1); + ul.push_front(u2); + ul.push_front(u3); + ul.push_front(u4); + ul.push_front(u5); + root = db.upsert(std::move(root), std::move(ul), block_id); + + size_t num_results = 0; + RangedGetMachine range_machine{ + min, + max, + [&num_results](NibblesView const, monad::byte_string_view const) { + ++num_results; + }}; + ASSERT_TRUE(db.traverse(root, range_machine, block_id)); + EXPECT_EQ(num_results, 2); +} + TEST_F(OnDiskDbFixture, rw_query_old_version) { auto const &kv = fixed_updates::kv; diff --git a/category/mpt/traverse_util.hpp b/category/mpt/traverse_util.hpp index 89910bc369..30245e94ab 100644 --- a/category/mpt/traverse_util.hpp +++ b/category/mpt/traverse_util.hpp @@ -23,6 +23,8 @@ #include #include +#include + MONAD_MPT_NAMESPACE_BEGIN using TraverseCallback = std::function; @@ -35,21 +37,16 @@ class RangedGetMachine : public TraverseMachine TraverseCallback callback_; private: - // This function is a looser version checking if min <= path < max. But it - // will also return true if we should continue traversing down. Suppose we - // have the range [0x00, 0x10] and we are at node 0x0. In that case the - // path's size is less than the min, check if it's as substring. + // Check if any descendant of path could fall within [min, max). When + // path is shorter than min, compare path against the truncated min + // prefix: e.g. range [0x0124, 0x1234) at path 0x1 should continue + // because descendants like 0x1000 are in range (0x1 >= 0x0). bool does_key_intersect_with_range(NibblesView const path) { - bool const min_check = [this, path] { - if (path.nibble_size() < min_.nibble_size()) { - return NibblesView{min_}.starts_with(path); - } - else { - return (path >= min_); - } - }(); - return min_check && (path < NibblesView{max_}); + auto const prefix_len = + std::min(path.nibble_size(), NibblesView{min_}.nibble_size()); + auto const min_prefix = NibblesView{min_}.substr(0, prefix_len); + return path >= min_prefix && path < NibblesView{max_}; } public: @@ -68,14 +65,14 @@ class RangedGetMachine : public TraverseMachine return true; } - auto next_path = + Nibbles next_path = concat(NibblesView{path_}, branch, node.path_nibble_view()); if (!does_key_intersect_with_range(next_path)) { return false; } path_ = std::move(next_path); - if (node.has_value() && path_.nibble_size() >= min_.nibble_size()) { + if (node.has_value() && path_ >= NibblesView{min_}) { callback_(path_, node.value()); }