Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 96 additions & 0 deletions category/mpt/test/db_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include <category/mpt/ondisk_db_config.hpp>
#include <category/mpt/test/test_fixtures_gtest.hpp>
#include <category/mpt/traverse.hpp>
#include <category/mpt/traverse_util.hpp>
#include <category/mpt/trie.hpp>
#include <category/mpt/update.hpp>
#include <category/mpt/util.hpp>
Expand Down Expand Up @@ -1658,6 +1659,101 @@ TYPED_TEST(DbTraverseTest, trimmed_traverse)
}
}

TEST(RangedGetTest, path_exceeds_min_prefix)
{
StateMachineAlwaysVarLen machine;
Db db{machine};
Node::SharedPtr root;

Nibbles const min{0x0124_bytes};
Nibbles const max{0x1234_bytes};

// in range: longer than min
Nibbles const k_in_long{0x10000000_bytes};
Nibbles const k_in_long_2{0x02000000_bytes};
// in range: shorter than min (0x02 > 0x0124)
Nibbles const k_in_short{0x02_bytes};
// in range: equals min (inclusive)
Nibbles const k_in_at_min{0x0124_bytes};
// out of range: above max (long)
Nibbles const k_out_long{0x20000000_bytes};
// out of range: above max (short)
Nibbles const k_out_short{0x13_bytes};
// out of range: equals max (exclusive)
Nibbles const k_out_at_max{0x1234_bytes};
auto const val = 0xdeadbeef_bytes;

uint64_t const block_id = 0;
auto u1 = make_update(k_in_long, val);
auto u2 = make_update(k_in_long_2, val);
auto u3 = make_update(k_in_short, val);
auto u4 = make_update(k_in_at_min, val);
auto u5 = make_update(k_out_long, val);
auto u6 = make_update(k_out_at_max, val);
auto u7 = make_update(k_out_short, val);
UpdateList ul;
ul.push_front(u1);
ul.push_front(u2);
ul.push_front(u3);
ul.push_front(u4);
ul.push_front(u5);
ul.push_front(u6);
ul.push_front(u7);
root = db.upsert(std::move(root), std::move(ul), block_id);

size_t num_results = 0;
RangedGetMachine range_machine{
min,
max,
[&num_results](NibblesView const, monad::byte_string_view const) {
++num_results;
}};
ASSERT_TRUE(db.traverse(root, range_machine, block_id));
EXPECT_EQ(num_results, 4);
}

TEST(RangedGetTest, path_longer_than_min_and_max)
{
StateMachineAlwaysVarLen machine;
Db db{machine};
Node::SharedPtr root;

// Very short bounds: all keys will have paths longer than both min and max.
Nibbles const min{0x03_bytes};
Nibbles const max{0x05_bytes};

Nibbles const k_in_1{0x03000000_bytes};
Nibbles const k_in_2{0x04ffffff_bytes};
Nibbles const k_out_below{0x02ffffff_bytes};
Nibbles const k_out_at_max{0x05000000_bytes};
Nibbles const k_out_above{0x06000000_bytes};
auto const val = 0xdeadbeef_bytes;

uint64_t const block_id = 0;
UpdateList ul;
auto u1 = make_update(k_in_1, val);
auto u2 = make_update(k_in_2, val);
auto u3 = make_update(k_out_below, val);
auto u4 = make_update(k_out_at_max, val);
auto u5 = make_update(k_out_above, val);
ul.push_front(u1);
ul.push_front(u2);
ul.push_front(u3);
ul.push_front(u4);
ul.push_front(u5);
root = db.upsert(std::move(root), std::move(ul), block_id);

size_t num_results = 0;
RangedGetMachine range_machine{
min,
max,
[&num_results](NibblesView const, monad::byte_string_view const) {
++num_results;
}};
ASSERT_TRUE(db.traverse(root, range_machine, block_id));
EXPECT_EQ(num_results, 2);
}

TEST_F(OnDiskDbFixture, rw_query_old_version)
{
auto const &kv = fixed_updates::kv;
Expand Down
27 changes: 12 additions & 15 deletions category/mpt/traverse_util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include <category/mpt/traverse.hpp>
#include <category/mpt/util.hpp>

#include <algorithm>

MONAD_MPT_NAMESPACE_BEGIN

using TraverseCallback = std::function<void(NibblesView, byte_string_view)>;
Expand All @@ -35,21 +37,16 @@ class RangedGetMachine : public TraverseMachine
TraverseCallback callback_;

private:
// This function is a looser version checking if min <= path < max. But it
// will also return true if we should continue traversing down. Suppose we
// have the range [0x00, 0x10] and we are at node 0x0. In that case the
// path's size is less than the min, check if it's as substring.
// Check if any descendant of path could fall within [min, max). When
// path is shorter than min, compare path against the truncated min
// prefix: e.g. range [0x0124, 0x1234) at path 0x1 should continue
// because descendants like 0x1000 are in range (0x1 >= 0x0).
bool does_key_intersect_with_range(NibblesView const path)
{
bool const min_check = [this, path] {
if (path.nibble_size() < min_.nibble_size()) {
return NibblesView{min_}.starts_with(path);
}
else {
return (path >= min_);
}
}();
return min_check && (path < NibblesView{max_});
auto const prefix_len =
std::min(path.nibble_size(), NibblesView{min_}.nibble_size());
auto const min_prefix = NibblesView{min_}.substr(0, prefix_len);
return path >= min_prefix && path < NibblesView{max_};
}

public:
Expand All @@ -68,14 +65,14 @@ class RangedGetMachine : public TraverseMachine
return true;
}

auto next_path =
Nibbles next_path =
concat(NibblesView{path_}, branch, node.path_nibble_view());
if (!does_key_intersect_with_range(next_path)) {
return false;
}

path_ = std::move(next_path);
if (node.has_value() && path_.nibble_size() >= min_.nibble_size()) {
if (node.has_value() && path_ >= NibblesView{min_}) {
callback_(path_, node.value());
}

Expand Down
Loading