Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,5 @@ testdb
build
cmake-build-*
build-*
search-tests

2 changes: 1 addition & 1 deletion src/cluster/sync_migrate_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
class SyncMigrateContext : private EvbufCallbackBase<SyncMigrateContext, false>,
private EventCallbackBase<SyncMigrateContext> {
public:
SyncMigrateContext(Server *srv, redis::Connection *conn, int timeout) : srv_(srv), conn_(conn), timeout_(timeout){};
SyncMigrateContext(Server *srv, redis::Connection *conn, int timeout) : srv_(srv), conn_(conn), timeout_(timeout) {};

void Suspend();
void Resume(const Status &migrate_result);
Expand Down
2 changes: 1 addition & 1 deletion src/commands/command_parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

template <typename Iter>
struct MoveIterator : Iter {
explicit MoveIterator(Iter iter) : Iter(iter){};
explicit MoveIterator(Iter iter) : Iter(iter) {};

typename Iter::value_type&& operator*() const { return std::move(this->Iter::operator*()); }
};
Expand Down
5 changes: 1 addition & 4 deletions src/commands/commander.cc
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,7 @@ StatusOr<std::vector<int>> CommandTable::GetKeysFromCommand(const CommandAttribu
[&](const std::vector<std::string> &, CommandKeyRange key_range) {
key_range.ForEachKeyIndex([&](int i) { key_indexes.push_back(i); }, cmd_tokens.size());
},
cmd_tokens,
[&](const auto &) {
status = {Status::NotOK, "The command has no key arguments"};
});
cmd_tokens, [&](const auto &) { status = {Status::NotOK, "The command has no key arguments"}; });

if (!status) {
return status;
Expand Down
2 changes: 1 addition & 1 deletion src/common/cron.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct CronPattern {

struct Interval {
int interval;
}; // */n
}; // */n
struct Any {}; // *
using Numbers = std::vector<std::variant<Number, Range>>; // 1,2,3-6,7

Expand Down
4 changes: 2 additions & 2 deletions src/common/rdb_stream.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class RdbStream {

class RdbStringStream : public RdbStream {
public:
explicit RdbStringStream(std::string_view input) : input_(input){};
explicit RdbStringStream(std::string_view input) : input_(input) {};
RdbStringStream(const RdbStringStream &) = delete;
RdbStringStream &operator=(const RdbStringStream &) = delete;
~RdbStringStream() override = default;
Expand All @@ -65,7 +65,7 @@ class RdbStringStream : public RdbStream {
class RdbFileStream : public RdbStream {
public:
explicit RdbFileStream(std::string file_name, size_t chunk_size = 1024 * 1024)
: file_name_(std::move(file_name)), check_sum_(0), total_read_bytes_(0), max_read_chunk_size_(chunk_size){};
: file_name_(std::move(file_name)), check_sum_(0), total_read_bytes_(0), max_read_chunk_size_(chunk_size) {};
RdbFileStream(const RdbFileStream &) = delete;
RdbFileStream &operator=(const RdbFileStream &) = delete;
~RdbFileStream() override = default;
Expand Down
2 changes: 1 addition & 1 deletion src/common/status.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ struct StringInStatusOr<T, std::enable_if_t<sizeof(T) < sizeof(std::string)>> :
StringInStatusOr(StringInStatusOr<U>&& v) : BaseType(new std::string(*std::move(v))) {} // NOLINT
template <typename U, typename std::enable_if_t<!StringInStatusOr<U>::inplace, int> = 0>
StringInStatusOr(StringInStatusOr<U>&& v) // NOLINT
: BaseType((typename StringInStatusOr<U>::BaseType &&)(std::move(v))) {}
: BaseType((typename StringInStatusOr<U>::BaseType&&)(std::move(v))) {}

StringInStatusOr(const StringInStatusOr& v) = delete;

Expand Down
3 changes: 1 addition & 2 deletions src/common/string_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ std::string StringJoin(const T &con, F &&f, std::string_view sep = ", ") {

template <typename T>
std::string StringJoin(const T &con, std::string_view sep = ", ") {
return StringJoin(
con, [](const auto &v) -> decltype(auto) { return v; }, sep);
return StringJoin(con, [](const auto &v) -> decltype(auto) { return v; }, sep);
}

} // namespace util
4 changes: 2 additions & 2 deletions src/config/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ Status SetRocksdbCompression(Server *srv, const rocksdb::CompressionType compres
for (size_t i = compression_start_level; i < KVROCKS_MAX_LSM_LEVEL; i++) {
compression_per_level_builder.emplace_back(compression_option);
}
const std::string compression_per_level = util::StringJoin(
compression_per_level_builder, [](const auto &s) -> decltype(auto) { return s; }, ":");
const std::string compression_per_level =
util::StringJoin(compression_per_level_builder, [](const auto &s) -> decltype(auto) { return s; }, ":");
return srv->storage->SetOptionForAllColumnFamilies("compression_per_level", compression_per_level);
};

Expand Down
3 changes: 1 addition & 2 deletions src/search/interval.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,8 +113,7 @@ struct IntervalSet {

std::string ToString() const {
if (IsEmpty()) return "empty set";
return util::StringJoin(
intervals, [](const auto &i) { return Interval(i.first, i.second).ToString(); }, " or ");
return util::StringJoin(intervals, [](const auto &i) { return Interval(i.first, i.second).ToString(); }, " or ");
}

friend std::ostream &operator<<(std::ostream &os, const IntervalSet &is) { return os << is.ToString(); }
Expand Down
2 changes: 1 addition & 1 deletion src/search/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ struct NumericCompareExpr : BoolAtomExpr {
struct VectorLiteral : Literal {
std::vector<double> values;

explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)){};
explicit VectorLiteral(std::vector<double> &&values) : values(std::move(values)) {};

std::string_view Name() const override { return "VectorLiteral"; }
std::string Dump() const override {
Expand Down
268 changes: 268 additions & 0 deletions src/search/passes/egraph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*
*/

#include "search/passes/egraph.h"

#include <algorithm>
#include <queue>
#include <unordered_map>
#include <unordered_set>

namespace kqir {

// Implementation of EClass::add
void EClass::add(ENode node) { nodes_.insert(std::move(node)); }

// Implementation of EGraph::add
ClassId EGraph::add(ENode node) {
// Create new class IDs for each child node if not already in the e-graph
std::vector<ClassId> new_children;
new_children.reserve(node.children().size());

for (ClassId child : node.children()) {
new_children.push_back(find_mutable(child));
}

// Update node with canonical class IDs
ENode new_node(node.op(), std::move(new_children));

// Check if this node already exists in the e-graph
for (auto& [id, eclass] : classes_) {
for (const auto& existing_node : eclass.nodes()) {
if (existing_node == new_node) {
return id;
}
}
}

// Create a new class for this node
ClassId id = next_id_++;
classes_.emplace(id, EClass(id));
parents_[id] = id;
classes_.at(id).add(std::move(new_node));

return id;
}

// Implementation of EGraph::find_mutable (non-const version)
ClassId EGraph::find_mutable(ClassId id) {
if (parents_.count(id) == 0) {
// Not found, return the original ID
return id;
}

// Path compression for union-find
if (parents_[id] != id) {
parents_[id] = find_mutable(parents_[id]);
}

return parents_[id];
}

// Implementation of EGraph::find (const version)
ClassId EGraph::find(ClassId id) const {
if (parents_.count(id) == 0) {
// Not found, return the original ID
return id;
}

// For const version, we can't do path compression
if (parents_.at(id) != id) {
return find(parents_.at(id));
}

return parents_.at(id);
}

// Implementation of EGraph::merge
ClassId EGraph::merge(ClassId id1, ClassId id2) {
ClassId root1 = find_mutable(id1);
ClassId root2 = find_mutable(id2);

if (root1 == root2) {
return root1;
}

// Union by rank (or just picking the first one for simplicity)
parents_[root2] = root1;

// Merge the equivalence classes
for (const auto& node : classes_.at(root2).nodes()) {
classes_.at(root1).add(node);
}

return root1;
}

// Implementation of EGraph::get_class
const EClass& EGraph::get_class(ClassId id) const { return classes_.at(id); }

// Implementation of EGraph::add_node
ClassId EGraph::add_node(const Node* node) {
if (node == nullptr) {
return 0; // Special ID for null nodes
}

// Create an ENode representation based on the KQIR node type
std::string op = std::string(node->Name());
std::vector<ClassId> children;

// Process child nodes recursively
for (auto it = const_cast<Node*>(node)->ChildBegin(); it != const_cast<Node*>(node)->ChildEnd(); ++it) {
Node* child = *it;
children.push_back(add_node(child));
}

// Add content to the op name to distinguish literals, field references, etc.
if (!node->Content().empty()) {
op += ":" + node->Content();
}

return add(ENode(op, std::move(children)));
}

// Implementation of EGraph::extract_best
std::unique_ptr<Node> EGraph::extract_best() {
// This default extraction just creates a new node tree based on the structure
// of the e-graph. A more sophisticated implementation would use a cost model.

std::unordered_map<ClassId, std::unique_ptr<Node>> extracted;

// Function to recursively extract nodes
std::function<std::unique_ptr<Node>(ClassId)> extract_recursive = [&](ClassId id) -> std::unique_ptr<Node> {
id = find_mutable(id);

// If already extracted, return a clone
if (extracted.count(id) > 0) {
return extracted.at(id)->Clone();
}

// Get the best node from this equivalence class
const EClass& eclass = get_class(id);
std::unique_ptr<Node> best_node;

// Find the first node that can be reconstructed
for (const auto& enode : eclass.nodes()) {
// Extract children first
std::vector<std::unique_ptr<Node>> child_nodes;
bool all_children_extracted = true;

for (ClassId child_id : enode.children()) {
auto child_node = extract_recursive(child_id);
if (child_node) {
child_nodes.push_back(std::move(child_node));
} else {
all_children_extracted = false;
break;
}
}

if (!all_children_extracted) {
continue;
}

// Create a new node based on the operator type
std::string_view op_name = enode.op();

// Parse the operator name to handle content-enhanced op names
std::string op_str(op_name);
std::string content;
size_t colon_pos = op_str.find(':');

if (colon_pos != std::string::npos) {
content = op_str.substr(colon_pos + 1);
op_str = op_str.substr(0, colon_pos);
}

// This is a simplified reconstruction that would need to be expanded
// based on the actual node types in your system
if (op_str == "Filter") {
if (child_nodes.size() == 2) {
auto source = Node::MustAs<PlanOperator>(std::move(child_nodes[0]));
auto filter_expr = Node::MustAs<QueryExpr>(std::move(child_nodes[1]));
best_node = std::make_unique<Filter>(std::move(source), std::move(filter_expr));
}
} else if (op_str == "Merge") {
std::vector<std::unique_ptr<PlanOperator>> ops;
for (auto& child : child_nodes) {
ops.push_back(Node::MustAs<PlanOperator>(std::move(child)));
}
best_node = std::make_unique<Merge>(std::move(ops));
}
// Add more node types as needed...

if (best_node) {
break;
}
}

if (best_node) {
extracted[id] = best_node->Clone();
}

return best_node;
};

// Start extraction from the root nodes
for (const auto& [id, _] : classes_) {
if (id == find(id)) { // Only consider canonical classes
auto node = extract_recursive(id);
if (node) {
return node;
}
}
}

return nullptr;
}

// Implementation of RuleSet::add
void RuleSet::add(std::unique_ptr<Rewrite> rule) { rules_.push_back(std::move(rule)); }

// Implementation of RuleSet::run_until_saturation
void RuleSet::run_until_saturation(EGraph& egraph, size_t max_iterations) {
size_t iterations = 0;
size_t prev_size = 0;

// Run until we reach max iterations or the e-graph stops growing
while (iterations < max_iterations) {
// Calculate the current size of the e-graph
size_t current_size = 0;
for (const auto& [_, eclass] : egraph.get_classes()) {
current_size += eclass.nodes().size();
}

// Check if we've reached saturation
if (iterations > 0 && current_size == prev_size) {
break;
}

prev_size = current_size;

// Apply all rewrite rules
for (auto& rule : rules_) {
rule->apply(egraph);
}

iterations++;
}
}

} // namespace kqir
Loading