diff --git a/.clangd b/.clangd new file mode 100644 index 000000000..3b4497187 --- /dev/null +++ b/.clangd @@ -0,0 +1,4 @@ +CompileFlags: + Add: + - "-ferror-limit=0" + - "-DRunningClangd" diff --git a/.config/typos.toml b/.config/typos.toml index 8a4db5958..6983cffd2 100644 --- a/.config/typos.toml +++ b/.config/typos.toml @@ -1,6 +1,7 @@ # See https://github.com/crate-ci/typos/blob/master/docs/reference.md to configure typos [files] +ignore-hidden = true extend-exclude = [ ] @@ -12,9 +13,14 @@ CrEaTe = "CrEaTe" LiSt = "LiSt" DeBuG = "DeBuG" DrOpInDeX = "DrOpInDeX" +Clangd = "Clangd" [type.cpp] extend-ignore-re = [ "baNAna", "eXIst", + "DIALEC", + "nghi", + "bubu", + "Teser" ] diff --git a/.vscode/cspell.json b/.vscode/cspell.json index 040ff0ce9..74887cd4e 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -14,6 +14,10 @@ "words": [ "absl", "bazel", + "bubu", + "Clandd", + "CLangd", + "DIALEC", "Externalizer", "highwayhash", "hnsw", @@ -21,11 +25,15 @@ "Inorder", "MRMW", "mstime", + "nghi", "NOLINTNEXTLINE", "nonexistentkey", "redis", "Redisearch", "synchronistically", + "Teser", + "upsert", + "upserts", "Valkey", "valkeysearch", "vmsdk" diff --git a/src/commands/CMakeLists.txt b/src/commands/CMakeLists.txt index 59320aae3..786f75998 100644 --- a/src/commands/CMakeLists.txt +++ b/src/commands/CMakeLists.txt @@ -12,6 +12,7 @@ set(SRCS_COMMANDS ${CMAKE_CURRENT_LIST_DIR}/ft_list.cc ${CMAKE_CURRENT_LIST_DIR}/ft_search.cc ${CMAKE_CURRENT_LIST_DIR}/commands.h + ${CMAKE_CURRENT_LIST_DIR}/commands.cc ${CMAKE_CURRENT_LIST_DIR}/ft_search.h) valkey_search_add_static_library(commands "${SRCS_COMMANDS}") diff --git a/src/commands/commands.cc b/src/commands/commands.cc new file mode 100644 index 000000000..24c23e65b --- /dev/null +++ b/src/commands/commands.cc @@ -0,0 +1,141 @@ +/* + * Copyright (c) 2025, valkey-search contributors + * All rights reserved. + * SPDX-License-Identifier: BSD 3-Clause + * + */ + +#include "src/commands/commands.h" + +#include "fanout.h" +#include "ft_create_parser.h" +#include "src/acl.h" +#include "src/commands/ft_search.h" +#include "src/query/search.h" +#include "src/schema_manager.h" +#include "src/valkey_search.h" +#include "valkey_search_options.h" +#include "vmsdk/src/debug.h" + +namespace valkey_search { +namespace async { + +struct Result { + cancel::Token cancellation_token; + absl::StatusOr> neighbors; + std::unique_ptr parameters; +}; + +int Timeout(ValkeyModuleCtx *ctx, [[maybe_unused]] ValkeyModuleString **argv, + [[maybe_unused]] int argc) { + return ValkeyModule_ReplyWithError( + ctx, "Search operation cancelled due to timeout"); +} + +int Reply(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { + auto *res = + static_cast(ValkeyModule_GetBlockedClientPrivateData(ctx)); + CHECK(res != nullptr); + if (!res->neighbors.ok()) { + ++Metrics::GetStats().query_failed_requests_cnt; + return ValkeyModule_ReplyWithError( + ctx, res->neighbors.status().message().data()); + } + res->parameters->SendReply(ctx, res->neighbors.value()); + return VALKEYMODULE_OK; +} + +void Free([[maybe_unused]] ValkeyModuleCtx *ctx, void *privdata) { + auto *result = static_cast(privdata); + delete result; +} + +} // namespace async + +CONTROLLED_BOOLEAN(ForceReplicasOnly, false); + +// +// Common Class for FT.SEARCH and FT.AGGREGATE command +// +absl::Status QueryCommand::Execute(ValkeyModuleCtx *ctx, + ValkeyModuleString **argv, int argc, + std::unique_ptr parameters) { + auto status = [&]() -> absl::Status { + auto &schema_manager = SchemaManager::Instance(); + vmsdk::ArgsIterator itr{argv + 1, argc - 1}; + parameters->timeout_ms = options::GetDefaultTimeoutMs().GetValue(); + VMSDK_RETURN_IF_ERROR( + vmsdk::ParseParamValue(itr, parameters->index_schema_name)); + VMSDK_ASSIGN_OR_RETURN( + parameters->index_schema, + SchemaManager::Instance().GetIndexSchema( + ValkeyModule_GetSelectedDb(ctx), parameters->index_schema_name)); + VMSDK_RETURN_IF_ERROR( + vmsdk::ParseParamValue(itr, parameters->parse_vars.query_string)); + VMSDK_RETURN_IF_ERROR(parameters->ParseCommand(itr)); + parameters->parse_vars.ClearAtEndOfParse(); + parameters->cancellation_token = + cancel::Make(parameters->timeout_ms, nullptr); + static const auto permissions = + PrefixACLPermissions(kSearchCmdPermissions, kSearchCommand); + VMSDK_RETURN_IF_ERROR(AclPrefixCheck( + ctx, permissions, parameters->index_schema->GetKeyPrefixes())); + + parameters->index_schema->ProcessMultiQueue(); + + const bool inside_multi_exec = vmsdk::MultiOrLua(ctx); + if (ABSL_PREDICT_FALSE(!ValkeySearch::Instance().SupportParallelQueries() || + inside_multi_exec)) { + VMSDK_ASSIGN_OR_RETURN( + auto neighbors, + query::Search(*parameters, query::SearchMode::kLocal)); + if (!options::GetEnablePartialResults().GetValue() && + parameters->cancellation_token->IsCancelled()) { + ValkeyModule_ReplyWithError( + ctx, "Search operation cancelled due to timeout"); + ++Metrics::GetStats().query_failed_requests_cnt; + return absl::OkStatus(); + } + parameters->SendReply(ctx, neighbors); + return absl::OkStatus(); + } + + vmsdk::BlockedClient blocked_client(ctx, async::Reply, async::Timeout, + async::Free, parameters->timeout_ms); + blocked_client.MeasureTimeStart(); + auto on_done_callback = [blocked_client = std::move(blocked_client)]( + auto &neighbors, auto parameters) mutable { + std::unique_ptr upcast_parameters( + dynamic_cast(parameters.release())); + CHECK(upcast_parameters != nullptr); + auto result = std::make_unique(async::Result{ + .neighbors = std::move(neighbors), + .parameters = std::move(upcast_parameters), + }); + blocked_client.SetReplyPrivateData(result.release()); + }; + + if (ValkeySearch::Instance().UsingCoordinator() && + ValkeySearch::Instance().IsCluster() && !parameters->local_only) { + auto mode = /* !vmsdk::IsReadOnly(ctx) ? query::fanout::kPrimaries ? */ + ForceReplicasOnly.GetValue() + ? query::fanout::FanoutTargetMode::kReplicasOnly + : query::fanout::FanoutTargetMode::kRandom; + auto search_targets = query::fanout::GetSearchTargetsForFanout(ctx, mode); + return query::fanout::PerformSearchFanoutAsync( + ctx, search_targets, + ValkeySearch::Instance().GetCoordinatorClientPool(), + std::move(parameters), ValkeySearch::Instance().GetReaderThreadPool(), + std::move(on_done_callback)); + } + return query::SearchAsync( + std::move(parameters), ValkeySearch::Instance().GetReaderThreadPool(), + std::move(on_done_callback), query::SearchMode::kLocal); + }(); + if (!status.ok()) { + ++Metrics::GetStats().query_failed_requests_cnt; + } + return status; +} + +} // namespace valkey_search diff --git a/src/commands/commands.h b/src/commands/commands.h index 3bc6f21f0..56502fb06 100644 --- a/src/commands/commands.h +++ b/src/commands/commands.h @@ -11,6 +11,8 @@ #include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" +#include "command_parser.h" +#include "src/query/search.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search { @@ -74,6 +76,41 @@ absl::Status FTDebugCmd(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc); absl::Status FTAggregateCmd(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc); + +// +// Common stuff for FT.SEARCH and FT.AGGREGATE command +// +struct QueryCommand : public query::SearchParameters { + QueryCommand() : query::SearchParameters(0, nullptr) {} + // + // Start of command. + // + static absl::Status Execute(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, + int argc, std::unique_ptr cmd); + + // + // Parse command (after index and query string) + // + virtual absl::Status ParseCommand(vmsdk::ArgsIterator &itr) = 0; + // + // Executed on Main Thread after merge + // + virtual void SendReply(ValkeyModuleCtx *ctx, + std::deque &neighbors) = 0; +}; + +namespace async { + +int Reply(ValkeyModuleCtx *ctx, [[maybe_unused]] ValkeyModuleString **argv, + [[maybe_unused]] int argc); + +int Timeout(ValkeyModuleCtx *ctx, [[maybe_unused]] ValkeyModuleString **argv, + [[maybe_unused]] int argc); + +void Free(ValkeyModuleCtx * /*ctx*/, void *privdata); + +} // namespace async + } // namespace valkey_search #endif // VALKEYSEARCH_SRC_COMMANDS_COMMANDS_H_ diff --git a/src/commands/ft_aggregate.cc b/src/commands/ft_aggregate.cc index 2daa8d190..cad82a533 100644 --- a/src/commands/ft_aggregate.cc +++ b/src/commands/ft_aggregate.cc @@ -4,28 +4,18 @@ * SPDX-License-Identifier: BSD 3-Clause */ -#include "src/commands/ft_aggregate.h" - -#include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "ft_search_parser.h" +#include "src/commands/commands.h" #include "src/commands/ft_aggregate_exec.h" -#include "src/commands/ft_create_parser.h" -#include "src/commands/ft_search.h" #include "src/index_schema.h" #include "src/indexes/index_base.h" #include "src/metrics.h" -#include "src/query/fanout.h" #include "src/query/response_generator.h" -#include "src/schema_manager.h" -#include "src/valkey_search.h" -#include "vmsdk/src/debug.h" - -// #define DBG std::cerr -#define DBG 0 && std::cerr namespace valkey_search { namespace aggregate { @@ -51,43 +41,34 @@ struct RealIndexInterface : public IndexInterface { absl::Status ManipulateReturnsClause(AggregateParameters ¶ms) { // Figure out what fields actually need to be returned by the aggregation // operation. And modify the common search returns list accordingly - DBG << "Manipulating returns clause for: " << params.index_schema_name - << "\n"; CHECK(!params.no_content); if (params.loadall_) { - DBG << "**LOADALL**\n"; CHECK(params.return_attributes.empty()); } else if (params.loads_.empty()) { // Nothing, don't load anything params.no_content = true; } else { - DBG << "LOADING: "; for (const auto &load : params.loads_) { - DBG << " " << load; // // Skip loading of the score and the key, we always get those... // if (load == "__key") { - DBG << " *skipped*"; params.load_key = true; continue; } if (load == vmsdk::ToStringView(params.score_as.get())) { - DBG << " *skipping score*"; continue; } VMSDK_ASSIGN_OR_RETURN(auto indexer, params.index_schema->GetIndex(load)); - auto field_type = indexer->GetIndexerType(); + auto indexer_type = indexer->GetIndexerType(); auto schema_identifier = params.index_schema->GetIdentifier(load); if (schema_identifier.ok()) { - DBG << " (alias: " << *schema_identifier << ", " << load << ")"; params.return_attributes.emplace_back(query::ReturnAttribute{ .identifier = vmsdk::MakeUniqueValkeyString(*schema_identifier), .attribute_alias = vmsdk::MakeUniqueValkeyString(load), .alias = vmsdk::MakeUniqueValkeyString(load)}); - params.AddRecordAttribute(*schema_identifier, load, field_type); + params.AddRecordAttribute(*schema_identifier, load, indexer_type); } else { - DBG << " " << load; params.return_attributes.emplace_back(query::ReturnAttribute{ .identifier = vmsdk::MakeUniqueValkeyString(load), .attribute_alias = vmsdk::UniqueValkeyString(), @@ -95,123 +76,89 @@ absl::Status ManipulateReturnsClause(AggregateParameters ¶ms) { params.AddRecordAttribute(load, load, indexes::IndexerType::kNone); } } - DBG << "\n"; } return absl::OkStatus(); } -absl::StatusOr> ParseCommand( - ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc, - const SchemaManager &schema_manager) { +absl::Status AggregateParameters::ParseCommand(vmsdk::ArgsIterator &itr) { static vmsdk::KeyValueParser parser = CreateAggregateParser(); - vmsdk::ArgsIterator itr{argv, argc}; - std::string index_schema_name; - VMSDK_RETURN_IF_ERROR(vmsdk::ParseParamValue(itr, index_schema_name)); - VMSDK_ASSIGN_OR_RETURN( - auto index_schema, - SchemaManager::Instance().GetIndexSchema(ValkeyModule_GetSelectedDb(ctx), - index_schema_name)); - RealIndexInterface index_interface(index_schema); - auto params = std::make_unique( - options::GetDefaultTimeoutMs().GetValue(), &index_interface); - DBG << "AggregateParameters created for index: " << index_schema_name << " @" - << (void *)params.get() << "\n"; - params->index_schema_name = std::move(index_schema_name); - params->index_schema = std::move(index_schema); - - VMSDK_RETURN_IF_ERROR( - vmsdk::ParseParamValue(itr, params->parse_vars.query_string)); + RealIndexInterface real_index_interface(index_schema); + parse_vars_.index_interface_ = &real_index_interface; - VMSDK_RETURN_IF_ERROR(PreParseQueryString(*params)); + VMSDK_RETURN_IF_ERROR(vmsdk::ParseParamValue(itr, parse_vars.query_string)); + VMSDK_RETURN_IF_ERROR(PreParseQueryString(*this)); // Ensure that key is first value if it gets included... - CHECK(params->AddRecordAttribute("__key", "__key", - indexes::IndexerType::kNone) == 0); - auto score_sv = vmsdk::ToStringView(params->score_as.get()); - CHECK(params->AddRecordAttribute(score_sv, score_sv, - indexes::IndexerType::kNone) == 1); + CHECK(AddRecordAttribute("__key", "__key", indexes::IndexerType::kNone) == 0); + auto score_sv = vmsdk::ToStringView(score_as.get()); + CHECK(AddRecordAttribute(score_sv, score_sv, indexes::IndexerType::kNone) == + 1); - VMSDK_RETURN_IF_ERROR(parser.Parse(*params, itr, true)); + VMSDK_RETURN_IF_ERROR(parser.Parse(*this, itr, true)); if (itr.DistanceEnd() > 0) { return absl::InvalidArgumentError( absl::StrCat("Unexpected parameter at position ", (itr.Position() + 1), ":", vmsdk::ToStringView(itr.Get().value()))); } - if (params->dialect < 2 || params->dialect > 4) { + if (dialect < 2 || dialect > 4) { return absl::InvalidArgumentError("Only Dialects 2, 3 and 4 are supported"); } - params->limit.number = - std::numeric_limits::max(); // Override default of 10 from - // search + limit.number = std::numeric_limits::max(); // Override default of + // 10 from search - VMSDK_RETURN_IF_ERROR(PostParseQueryString(*params)); - VMSDK_RETURN_IF_ERROR(ManipulateReturnsClause(*params)); + VMSDK_RETURN_IF_ERROR(PostParseQueryString(*this)); + VMSDK_RETURN_IF_ERROR(ManipulateReturnsClause(*this)); - DBG << "At end of parse: " << *params << "\n"; - params->parse_vars.ClearAtEndOfParse(); - return std::move(params); + return absl::OkStatus(); } bool ReplyWithValue(ValkeyModuleCtx *ctx, data_model::AttributeDataType data_type, - std::string_view name, indexes::IndexerType field_type, + std::string_view name, indexes::IndexerType indexer_type, const expr::Value &value, int dialect) { if (value.IsNil()) { return false; + } + if (data_type == data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) { + ValkeyModule_ReplyWithSimpleString(ctx, name.data()); + auto value_sv = value.AsStringView(); + ValkeyModule_ReplyWithStringBuffer(ctx, value_sv.data(), value_sv.size()); } else { - DBG << "ReplyWithValue " << name << "\n"; - if (data_type == data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) { - ValkeyModule_ReplyWithSimpleString(ctx, name.data()); - auto value_sv = value.AsStringView(); - ValkeyModule_ReplyWithStringBuffer(ctx, value_sv.data(), value_sv.size()); - DBG << "HASH: " << name << ":" << value_sv << "\n"; + char double_storage[50]; + std::string_view value_view; + if (name == "$") { + value_view = value.AsStringView(); } else { - char double_storage[50]; - std::string_view value_view; - if (name == "$") { - value_view = value.AsStringView(); - DBG << "Overriding for field name of $ " << int(field_type) << "\n"; - DBG << "Input: " << value.AsStringView() << "\n"; - } else { - switch (field_type) { - case indexes::IndexerType::kTag: - case indexes::IndexerType::kNone: { - value_view = value.AsStringView(); - DBG << "JSON kTag: " << value_view << "\n"; - break; - } - case indexes::IndexerType::kNumeric: { - auto dble = value.AsDouble(); - if (!dble) { - return false; - } - auto double_size = snprintf(double_storage, sizeof(double_storage), - "%.11g", *dble); - value_view = std::string_view(double_storage, double_size); - DBG << "JSON kNumeric:" << value_view << "\n"; - break; + switch (indexer_type) { + case indexes::IndexerType::kTag: + case indexes::IndexerType::kNone: { + value_view = value.AsStringView(); + break; + } + case indexes::IndexerType::kNumeric: { + auto dble = value.AsDouble(); + if (!dble) { + return false; } - default: - DBG << "Unsupported field type for reply: " << int(field_type) - << "\n"; - assert("Unsupported field type" == nullptr); + auto double_size = + snprintf(double_storage, sizeof(double_storage), "%.11g", *dble); + value_view = std::string_view(double_storage, double_size); + break; } - } - ValkeyModule_ReplyWithSimpleString(ctx, name.data()); - if (dialect == 2) { - ValkeyModule_ReplyWithStringBuffer(ctx, value_view.data(), - value_view.size()); - } else { - std::string s; - s = '['; - s += value_view; - s += ']'; - DBG << "Dialect != 2: " << s << "\n"; - ValkeyModule_ReplyWithStringBuffer(ctx, s.data(), s.size()); + default: + CHECK(false) << " Received type " << int(indexer_type); } } + ValkeyModule_ReplyWithSimpleString(ctx, name.data()); + if (dialect == 2) { + ValkeyModule_ReplyWithStringBuffer(ctx, value_view.data(), + value_view.size()); + } else { + std::string s = absl::StrCat("[", value_view, "]"); + ValkeyModule_ReplyWithStringBuffer(ctx, s.data(), s.size()); + } } return true; } @@ -246,48 +193,41 @@ absl::Status SendReplyInner(ValkeyModuleCtx *ctx, // auto data_type = parameters.index_schema->GetAttributeDataType().ToProto(); RecordSet records(¶meters); - // Todo: fix this for (auto &n : neighbors) { for (auto &n : neighbors) { auto rec = std::make_unique(parameters.record_indexes_by_alias_.size()); - DBG << "Neighbor: " << n << " Empty Record:" << *rec << "\n"; if (parameters.load_key) { rec->fields_.at(key_index) = expr::Value(n.external_id.get()->Str()); } - if (/* todo: parameters.addscores_ */ true) { + if (parameters.IsVectorQuery()) { rec->fields_.at(scores_index) = expr::Value(n.distance); } // For the fields that were fetched, stash them into the RecordSet if (n.attribute_contents.has_value() && !parameters.no_content) { for (auto &[name, records_map_value] : *n.attribute_contents) { auto value = vmsdk::ToStringView(records_map_value.value.get()); - size_t record_index; - bool found_index = false; + std::optional record_index; if (auto by_alias = parameters.record_indexes_by_alias_.find(name); by_alias != parameters.record_indexes_by_alias_.end()) { record_index = by_alias->second; - found_index = true; assert(record_index < rec->field_.size()); } else if (auto by_identifier = parameters.record_indexes_by_identifier_.find(name); by_identifier != parameters.record_indexes_by_identifier_.end()) { record_index = by_identifier->second; - found_index = true; assert(record_index < rec->field_.size()); } - if (found_index) { + if (record_index) { // Need to find the field type - indexes::IndexerType field_type = - parameters.record_info_by_index_[record_index].data_type_; - DBG << "Attribute_contents: " << name << " : " << value - << " Index:" << record_index << " FieldType:" << int(field_type) - << "\n"; - switch (field_type) { + indexes::IndexerType indexer_type = + parameters.record_info_by_index_[*record_index].data_type_; + switch (indexer_type) { case indexes::IndexerType::kNumeric: { auto numeric_value = vmsdk::To(value); if (numeric_value.ok()) { - rec->fields_[record_index] = expr::Value(numeric_value.value()); + rec->fields_[*record_index] = + expr::Value(numeric_value.value()); } else { // Skip this field, it contains an invalid number.... // todo Prove that skipping this field is the right thing to @@ -298,24 +238,18 @@ absl::Status SendReplyInner(ValkeyModuleCtx *ctx, default: if (data_type == data_model::AttributeDataType::ATTRIBUTE_DATA_TYPE_HASH) { - rec->fields_[record_index] = expr::Value(value); + rec->fields_[*record_index] = expr::Value(value); } else { auto v = vmsdk::JsonUnquote(value); if (v) { - DBG << "De-quoting:\n" - << value << "\nBecame:\n" - << *v << "\n"; - rec->fields_[record_index] = expr::Value(std::move(*v)); + rec->fields_[*record_index] = expr::Value(std::move(*v)); } else { goto drop_record; } } break; } - // DBG << "After set record is " << *rec << "\n"; } else { - DBG << "Attribute_contents: " << name << " : " << value - << " Extra:\n"; rec->extra_fields_.push_back( std::make_pair(std::string(name), expr::Value(value))); } @@ -324,7 +258,6 @@ absl::Status SendReplyInner(ValkeyModuleCtx *ctx, records.push_back(std::move(rec)); drop_record:; } - DBG << "After Record Fetch\n" << records << "\n"; // // 2. Perform the aggregation stages // @@ -332,7 +265,6 @@ absl::Status SendReplyInner(ValkeyModuleCtx *ctx, // Todo Check for timeout VMSDK_RETURN_IF_ERROR(stage->Execute(records)); } - DBG << ">> Finished stages\n" << records; // // 3. Generate the result @@ -366,18 +298,15 @@ absl::Status SendReplyInner(ValkeyModuleCtx *ctx, array_count += 2; } } - DBG << " (Total length) " << array_count << "\n"; ValkeyModule_ReplySetArrayLength(ctx, array_count); } return absl::OkStatus(); } -void SendAggReply(ValkeyModuleCtx *ctx, - std::deque &neighbors, - AggregateParameters ¶meters) { - auto identifier = - parameters.index_schema->GetIdentifier(parameters.attribute_alias); - auto result = SendReplyInner(ctx, neighbors, parameters); +void AggregateParameters::SendReply(ValkeyModuleCtx *ctx, + std::deque &neighbors) { + auto identifier = index_schema->GetIdentifier(attribute_alias); + auto result = SendReplyInner(ctx, neighbors, *this); if (!result.ok()) { ++Metrics::GetStats().query_failed_requests_cnt; ValkeyModule_ReplyWithError(ctx, result.message().data()); @@ -386,60 +315,11 @@ void SendAggReply(ValkeyModuleCtx *ctx, } // namespace aggregate -CONTROLLED_BOOLEAN(AggForceReplicasOnly, false); - absl::Status FTAggregateCmd(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { - auto status = [&]() { - auto &schema_manager = SchemaManager::Instance(); - VMSDK_ASSIGN_OR_RETURN( - auto parameters, - aggregate::ParseCommand(ctx, argv + 1, argc - 1, schema_manager)); - parameters->index_schema->ProcessMultiQueue(); - bool inside_multi = - (ValkeyModule_GetContextFlags(ctx) & VALKEYMODULE_CTX_FLAGS_MULTI) != 0; - if (ABSL_PREDICT_FALSE(!ValkeySearch::Instance().SupportParallelQueries() || - inside_multi)) { - VMSDK_ASSIGN_OR_RETURN( - auto neighbors, - query::Search(*parameters, query::SearchMode::kLocal)); - SendAggReply(ctx, neighbors, *parameters); - return absl::OkStatus(); - } - vmsdk::BlockedClient blocked_client(ctx, async::Reply, async::Timeout, - async::Free, 0); - blocked_client.MeasureTimeStart(); - auto on_done_callback = [blocked_client = std::move(blocked_client)]( - auto &neighbors, auto parameters) mutable { - auto result = std::make_unique(async::Result{ - .neighbors = std::move(neighbors), - .parameters = std::move(parameters), - }); - blocked_client.SetReplyPrivateData(result.release()); - }; - - if (ValkeySearch::Instance().UsingCoordinator() && - ValkeySearch::Instance().IsCluster() && !parameters->local_only) { - auto mode = /* !vmsdk::IsReadOnly(ctx) ? query::fanout::kPrimaries ? */ - AggForceReplicasOnly.GetValue() - ? query::fanout::FanoutTargetMode::kReplicasOnly - : query::fanout::FanoutTargetMode::kRandom; - auto search_targets = query::fanout::GetSearchTargetsForFanout(ctx, mode); - return query::fanout::PerformSearchFanoutAsync( - ctx, search_targets, - ValkeySearch::Instance().GetCoordinatorClientPool(), - std::move(parameters), ValkeySearch::Instance().GetReaderThreadPool(), - std::move(on_done_callback)); - } - return query::SearchAsync( - std::move(parameters), ValkeySearch::Instance().GetReaderThreadPool(), - std::move(on_done_callback), query::SearchMode::kLocal); - CHECK(false); - }(); - if (!status.ok()) { - ++Metrics::GetStats().query_failed_requests_cnt; - } - return status; + return QueryCommand::Execute( + ctx, argv, argc, + std::unique_ptr(new aggregate::AggregateParameters)); } } // namespace valkey_search diff --git a/src/commands/ft_aggregate.h b/src/commands/ft_aggregate.h index 32a9941df..eadfaa89e 100644 --- a/src/commands/ft_aggregate.h +++ b/src/commands/ft_aggregate.h @@ -8,7 +8,7 @@ #define VALKEYSEARCH_SRC_COMMANDS_FT_AGGREGATE_H #include "absl/status/status.h" -#include "src/commands/ft_aggregate_parser.h" +#include "valkey_module.h" namespace valkey_search { namespace aggregate { @@ -16,11 +16,6 @@ namespace aggregate { absl::Status FTAggregateCmd(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc); -struct AggregateParameters; -void SendAggReply(ValkeyModuleCtx *ctx, - std::deque &neighbors, - AggregateParameters ¶meters); - } // namespace aggregate }; // namespace valkey_search #endif diff --git a/src/commands/ft_aggregate_parser.h b/src/commands/ft_aggregate_parser.h index 6e5962bae..db253ec04 100644 --- a/src/commands/ft_aggregate_parser.h +++ b/src/commands/ft_aggregate_parser.h @@ -10,6 +10,7 @@ #include "absl/container/flat_hash_map.h" #include "absl/container/inlined_vector.h" #include "absl/status/status.h" +#include "src/commands/commands.h" #include "src/expr/expr.h" #include "src/expr/value.h" #include "src/query/search.h" @@ -34,8 +35,12 @@ struct IndexInterface { }; struct AggregateParameters : public expr::Expression::CompileContext, - public query::SearchParameters { + public QueryCommand { ~AggregateParameters() override = default; + AggregateParameters() = default; + absl::Status ParseCommand(vmsdk::ArgsIterator& itr) override; + void SendReply(ValkeyModuleCtx* ctx, + std::deque& neighbors) override; bool loadall_{false}; std::vector loads_; bool load_key{false}; @@ -117,11 +122,6 @@ struct AggregateParameters : public expr::Expression::CompileContext, parse_vars.ClearAtEndOfParse(); } - AggregateParameters(uint64_t timeout, IndexInterface* index_interface) - : query::SearchParameters(timeout, nullptr) { - parse_vars_.index_interface_ = index_interface; - } - friend std::ostream& operator<<(std::ostream& os, const AggregateParameters& agg); }; @@ -267,7 +267,7 @@ class SortBy : public Stage { } }; -absl::StatusOr> ParseAggregateParameters( +absl::StatusOr> ParseAggregateParameters( ValkeyModuleCtx* ctx, ValkeyModuleString** argv, int argc, const SchemaManager& schema_manager); diff --git a/src/commands/ft_debug.cc b/src/commands/ft_debug.cc index c4ad9a6a2..97e6a27cb 100644 --- a/src/commands/ft_debug.cc +++ b/src/commands/ft_debug.cc @@ -9,7 +9,6 @@ #include #include "module_config.h" -#include "src/commands/commands.h" #include "vmsdk/src/command_parser.h" #include "vmsdk/src/debug.h" #include "vmsdk/src/info.h" diff --git a/src/commands/ft_search.cc b/src/commands/ft_search.cc index 0dadfd679..72c9197ad 100644 --- a/src/commands/ft_search.cc +++ b/src/commands/ft_search.cc @@ -5,42 +5,28 @@ * */ -#include "src/commands/ft_search.h" - #include #include #include #include #include -#include #include #include #include -#include "absl/base/optimization.h" #include "absl/log/check.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_format.h" #include "absl/strings/string_view.h" -#include "src/acl.h" #include "src/commands/commands.h" -#include "src/commands/ft_aggregate.h" #include "src/commands/ft_search_parser.h" #include "src/indexes/vector_base.h" #include "src/metrics.h" -#include "src/query/fanout.h" #include "src/query/response_generator.h" #include "src/query/search.h" -#include "src/schema_manager.h" -#include "src/valkey_search.h" -#include "src/valkey_search_options.h" -#include "vmsdk/src/blocked_client.h" -#include "vmsdk/src/debug.h" #include "vmsdk/src/managed_pointers.h" -#include "vmsdk/src/module_config.h" -#include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/type_conversions.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" @@ -172,145 +158,44 @@ void SerializeNonVectorNeighbors(ValkeyModuleCtx *ctx, // 3. Attribute name // 4. The vector value // SendReply respects the Limit, see https://valkey.io/commands/ft.search/ -void SendReply(ValkeyModuleCtx *ctx, std::deque &neighbors, - const query::SearchParameters ¶meters) { - if (!options::GetEnablePartialResults().GetValue() && - parameters.cancellation_token->IsCancelled()) { - ValkeyModule_ReplyWithError(ctx, - "Search operation cancelled due to timeout"); - ++Metrics::GetStats().query_failed_requests_cnt; - return; - } - if (auto agg = dynamic_cast( - const_cast(¶meters))) { - aggregate::SendAggReply(ctx, neighbors, *agg); - return; - } +void SearchCommand::SendReply(ValkeyModuleCtx *ctx, + std::deque &neighbors) { // Increment success counter. ++Metrics::GetStats().query_successful_requests_cnt; // Support non-vector queries: no attribute_alias and k == 0 - if (parameters.IsNonVectorQuery()) { + if (IsNonVectorQuery()) { query::ProcessNonVectorNeighborsForReply( - ctx, parameters.index_schema->GetAttributeDataType(), neighbors, - parameters); - SerializeNonVectorNeighbors(ctx, neighbors, parameters); + ctx, index_schema->GetAttributeDataType(), neighbors, *this); + SerializeNonVectorNeighbors(ctx, neighbors, *this); return; } - if (parameters.limit.first_index >= static_cast(parameters.k) || - parameters.limit.number == 0) { + if (limit.first_index >= static_cast(k) || limit.number == 0) { ValkeyModule_ReplyWithArray(ctx, 1); ValkeyModule_ReplyWithLongLong(ctx, neighbors.size()); return; } - if (parameters.no_content) { - SendReplyNoContent(ctx, neighbors, parameters); + if (no_content) { + SendReplyNoContent(ctx, neighbors, *this); return; } - auto identifier = - parameters.index_schema->GetIdentifier(parameters.attribute_alias); + auto identifier = index_schema->GetIdentifier(attribute_alias); if (!identifier.ok()) { ++Metrics::GetStats().query_failed_requests_cnt; ValkeyModule_ReplyWithError(ctx, identifier.status().message().data()); return; } - query::ProcessNeighborsForReply( - ctx, parameters.index_schema->GetAttributeDataType(), neighbors, - parameters, identifier.value()); + query::ProcessNeighborsForReply(ctx, index_schema->GetAttributeDataType(), + neighbors, *this, identifier.value()); - SerializeNeighbors(ctx, neighbors, parameters); -} - -namespace async { - -int Timeout(ValkeyModuleCtx *ctx, [[maybe_unused]] ValkeyModuleString **argv, - [[maybe_unused]] int argc) { - return ValkeyModule_ReplyWithError( - ctx, "Search operation cancelled due to timeout"); -} - -int Reply(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { - auto *res = - static_cast(ValkeyModule_GetBlockedClientPrivateData(ctx)); - CHECK(res != nullptr); - if (!res->neighbors.ok()) { - ++Metrics::GetStats().query_failed_requests_cnt; - return ValkeyModule_ReplyWithError( - ctx, res->neighbors.status().message().data()); - } - SendReply(ctx, res->neighbors.value(), *res->parameters); - return VALKEYMODULE_OK; + SerializeNeighbors(ctx, neighbors, *this); } -void Free([[maybe_unused]] ValkeyModuleCtx *ctx, void *privdata) { - auto *result = static_cast(privdata); - delete result; -} - -} // namespace async - -CONTROLLED_BOOLEAN(ForceReplicasOnly, false); - absl::Status FTSearchCmd(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, int argc) { - auto status = [&]() -> absl::Status { - auto &schema_manager = SchemaManager::Instance(); - VMSDK_ASSIGN_OR_RETURN( - auto parameters, - ParseVectorSearchParameters(ctx, argv + 1, argc - 1, schema_manager)); - parameters->cancellation_token = - cancel::Make(parameters->timeout_ms, nullptr); - static const auto permissions = - PrefixACLPermissions(kSearchCmdPermissions, kSearchCommand); - VMSDK_RETURN_IF_ERROR(AclPrefixCheck( - ctx, permissions, parameters->index_schema->GetKeyPrefixes())); - - parameters->index_schema->ProcessMultiQueue(); - - const bool inside_multi_exec = vmsdk::MultiOrLua(ctx); - if (ABSL_PREDICT_FALSE(!ValkeySearch::Instance().SupportParallelQueries() || - inside_multi_exec)) { - VMSDK_ASSIGN_OR_RETURN( - auto neighbors, - query::Search(*parameters, query::SearchMode::kLocal)); - SendReply(ctx, neighbors, *parameters); - return absl::OkStatus(); - } - - vmsdk::BlockedClient blocked_client(ctx, async::Reply, async::Timeout, - async::Free, parameters->timeout_ms); - blocked_client.MeasureTimeStart(); - auto on_done_callback = [blocked_client = std::move(blocked_client)]( - auto &neighbors, auto parameters) mutable { - auto result = std::make_unique(async::Result{ - .neighbors = std::move(neighbors), - .parameters = std::move(parameters), - }); - blocked_client.SetReplyPrivateData(result.release()); - }; - - if (ValkeySearch::Instance().UsingCoordinator() && - ValkeySearch::Instance().IsCluster() && !parameters->local_only) { - auto mode = /* !vmsdk::IsReadOnly(ctx) ? query::fanout::kPrimaries ? */ - ForceReplicasOnly.GetValue() - ? query::fanout::FanoutTargetMode::kReplicasOnly - : query::fanout::FanoutTargetMode::kRandom; - auto search_targets = query::fanout::GetSearchTargetsForFanout(ctx, mode); - return query::fanout::PerformSearchFanoutAsync( - ctx, search_targets, - ValkeySearch::Instance().GetCoordinatorClientPool(), - std::move(parameters), ValkeySearch::Instance().GetReaderThreadPool(), - std::move(on_done_callback)); - } - return query::SearchAsync( - std::move(parameters), ValkeySearch::Instance().GetReaderThreadPool(), - std::move(on_done_callback), query::SearchMode::kLocal); - }(); - if (!status.ok()) { - ++Metrics::GetStats().query_failed_requests_cnt; - } - return status; + return QueryCommand::Execute( + ctx, argv, argc, std::unique_ptr(new SearchCommand)); } } // namespace valkey_search diff --git a/src/commands/ft_search.h b/src/commands/ft_search.h index 5abcc1d56..18b5a536b 100644 --- a/src/commands/ft_search.h +++ b/src/commands/ft_search.h @@ -8,36 +8,8 @@ #ifndef VALKEYSEARCH_SRC_COMMANDS_FT_SEARCH_H_ #define VALKEYSEARCH_SRC_COMMANDS_FT_SEARCH_H_ -#include -#include - -#include "absl/status/statusor.h" -#include "src/indexes/vector_base.h" -#include "src/query/search.h" -#include "vmsdk/src/valkey_module_api/valkey_module.h" - namespace valkey_search { class ValkeySearch; -// Declared here to support testing -void SendReply(ValkeyModuleCtx *ctx, std::deque &neighbors, - const query::SearchParameters ¶meters); -namespace async { - -struct Result { - cancel::Token cancellation_token; - absl::StatusOr> neighbors; - std::unique_ptr parameters; -}; - -int Reply(ValkeyModuleCtx *ctx, [[maybe_unused]] ValkeyModuleString **argv, - [[maybe_unused]] int argc); - -int Timeout(ValkeyModuleCtx *ctx, [[maybe_unused]] ValkeyModuleString **argv, - [[maybe_unused]] int argc); - -void Free(ValkeyModuleCtx * /*ctx*/, void *privdata); - -} // namespace async } // namespace valkey_search #endif // VALKEYSEARCH_SRC_COMMANDS_FT_SEARCH_H_ diff --git a/src/commands/ft_search_parser.cc b/src/commands/ft_search_parser.cc index 613379d66..040ee4e42 100644 --- a/src/commands/ft_search_parser.cc +++ b/src/commands/ft_search_parser.cc @@ -29,14 +29,12 @@ #include "src/indexes/index_base.h" #include "src/metrics.h" #include "src/query/search.h" -#include "src/schema_manager.h" #include "src/valkey_search_options.h" #include "vmsdk/src/command_parser.h" #include "vmsdk/src/managed_pointers.h" #include "vmsdk/src/module_config.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/type_conversions.h" -#include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search { @@ -445,30 +443,17 @@ absl::Status PostParseQueryString(query::SearchParameters ¶meters) { return absl::OkStatus(); } -absl::StatusOr> -ParseVectorSearchParameters(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, - int argc, const SchemaManager &schema_manager) { - vmsdk::ArgsIterator itr{argv, argc}; - auto parameters = std::make_unique( - options::GetDefaultTimeoutMs().GetValue(), nullptr); - VMSDK_RETURN_IF_ERROR( - vmsdk::ParseParamValue(itr, parameters->index_schema_name)); - VMSDK_ASSIGN_OR_RETURN( - parameters->index_schema, - SchemaManager::Instance().GetIndexSchema(ValkeyModule_GetSelectedDb(ctx), - parameters->index_schema_name)); - VMSDK_RETURN_IF_ERROR( - vmsdk::ParseParamValue(itr, parameters->parse_vars.query_string)); - VMSDK_RETURN_IF_ERROR(SearchParser.Parse(*parameters, itr)); +absl::Status SearchCommand::ParseCommand(vmsdk::ArgsIterator &itr) { + VMSDK_RETURN_IF_ERROR(SearchParser.Parse(*this, itr)); if (itr.DistanceEnd() > 0) { return absl::InvalidArgumentError( absl::StrCat("Unexpected parameter at position ", (itr.Position() + 1), ":", vmsdk::ToStringView(itr.Get().value()))); } - VMSDK_RETURN_IF_ERROR(PreParseQueryString(*parameters)); - VMSDK_RETURN_IF_ERROR(PostParseQueryString(*parameters)); - VMSDK_RETURN_IF_ERROR(Verify(*parameters)); - parameters->parse_vars.ClearAtEndOfParse(); - return parameters; + VMSDK_RETURN_IF_ERROR(PreParseQueryString(*this)); + VMSDK_RETURN_IF_ERROR(PostParseQueryString(*this)); + VMSDK_RETURN_IF_ERROR(Verify(*this)); + return absl::OkStatus(); } + } // namespace valkey_search diff --git a/src/commands/ft_search_parser.h b/src/commands/ft_search_parser.h index c1acbbe9f..e8bf69332 100644 --- a/src/commands/ft_search_parser.h +++ b/src/commands/ft_search_parser.h @@ -8,13 +8,10 @@ #ifndef VALKEYSEARCH_SRC_COMMANDS_FT_SEARCH_PARSER_H_ #define VALKEYSEARCH_SRC_COMMANDS_FT_SEARCH_PARSER_H_ -#include #include -#include -#include "absl/status/statusor.h" +#include "src/commands/commands.h" #include "src/query/search.h" -#include "src/schema_manager.h" #include "vmsdk/src/valkey_module_api/valkey_module.h" namespace valkey_search { @@ -30,9 +27,14 @@ struct LimitParameter { absl::Status PreParseQueryString(query::SearchParameters ¶meters); absl::Status PostParseQueryString(query::SearchParameters ¶meters); -absl::StatusOr> -ParseVectorSearchParameters(ValkeyModuleCtx *ctx, ValkeyModuleString **argv, - int argc, const SchemaManager &schema_manager); +// +// Data Unique to the FT.SEARCH command +// +struct SearchCommand : public QueryCommand { + absl::Status ParseCommand(vmsdk::ArgsIterator &itr) override; + void SendReply(ValkeyModuleCtx *ctx, + std::deque &neighbors) override; +}; } // namespace valkey_search #endif // VALKEYSEARCH_SRC_COMMANDS_FT_SEARCH_PARSER_H_ diff --git a/testing/CMakeLists.txt b/testing/CMakeLists.txt index e47bfc9b6..d39e84f38 100644 --- a/testing/CMakeLists.txt +++ b/testing/CMakeLists.txt @@ -54,6 +54,8 @@ target_link_libraries(testing_common_coordinator INTERFACE client_pool) # 1. Commands Test Suite - consolidates FT command related tests set(COMMANDS_TEST_SOURCES + ${CMAKE_CURRENT_LIST_DIR}/ft_aggregate_exec_test.cc + ${CMAKE_CURRENT_LIST_DIR}/ft_aggregate_parser_test.cc ${CMAKE_CURRENT_LIST_DIR}/ft_create_parser_test.cc ${CMAKE_CURRENT_LIST_DIR}/ft_search_parser_test.cc ${CMAKE_CURRENT_LIST_DIR}/ft_search_test.cc diff --git a/testing/common.h b/testing/common.h index 3b9ca012c..36139abf6 100644 --- a/testing/common.h +++ b/testing/common.h @@ -40,11 +40,8 @@ #include "src/server_events.h" #include "src/utils/string_interning.h" #include "src/valkey_search.h" -#include "src/valkey_search_options.h" #include "src/vector_externalizer.h" -#include "third_party/hnswlib/iostream.h" #include "vmsdk/src/managed_pointers.h" -#include "vmsdk/src/module_config.h" #include "vmsdk/src/status/status_macros.h" #include "vmsdk/src/testing_infra/module.h" #include "vmsdk/src/testing_infra/utils.h" diff --git a/testing/ft_aggregate_exec_test.cc b/testing/ft_aggregate_exec_test.cc index 809d090ab..fef9df76d 100644 --- a/testing/ft_aggregate_exec_test.cc +++ b/testing/ft_aggregate_exec_test.cc @@ -61,22 +61,23 @@ static RecordSet MakeData(size_t m) { return result; } -struct AggregateExecTest : public vmsdk::RedisTest { +struct AggregateExecTest : public vmsdk::ValkeyTest { void SetUp() override { fakeIndex.fields_ = { {"n1", indexes::IndexerType::kNumeric}, {"n2", indexes::IndexerType::kNumeric}, }; - vmsdk::RedisTest::SetUp(); + vmsdk::ValkeyTest::SetUp(); } - void TearDown() override { vmsdk::RedisTest::TearDown(); } + void TearDown() override { vmsdk::ValkeyTest::TearDown(); } FakeIndexInterface fakeIndex; std::unique_ptr MakeStages(absl::string_view test) { - auto argv = vmsdk::ToRedisStringVector(test); + auto argv = vmsdk::ToValkeyStringVector(test); vmsdk::ArgsIterator itr(argv.data(), argv.size()); - auto params = std::make_unique(&fakeIndex); + auto params = std::make_unique(); + params->parse_vars_.index_interface_ = &fakeIndex; EXPECT_EQ( params->AddRecordAttribute("n1", "n1", indexes::IndexerType::kNumeric), 0); @@ -245,7 +246,7 @@ TEST_F(AggregateExecTest, ReducerTest) { {"groupby 1 @n2 reduce sum 1 @n1", 4, {6}}, {"groupby 1 @n2 reduce stddev 1 @n1", 4, {1.2909944487358056}}, {"groupby 1 @n2 reduce count_distinct 1 @n1", 4, {4}}, - }; + {"groupby 1 @n2 reduce avg 1 @n1", 4, {1.5}}}; for (auto& tc : testcases) { std::cerr << "GroupTest: " << tc.text_ << "\n"; auto param = MakeStages(tc.text_); diff --git a/testing/ft_aggregate_parser_test.cc b/testing/ft_aggregate_parser_test.cc index d19696571..47da7e395 100644 --- a/testing/ft_aggregate_parser_test.cc +++ b/testing/ft_aggregate_parser_test.cc @@ -51,15 +51,15 @@ struct FakeIndexInterface : public IndexInterface { } }; -struct AggregateTest : public vmsdk::RedisTest { +struct AggregateTest : public vmsdk::ValkeyTest { void SetUp() override { fake_index.fields_ = { {"n1", indexes::IndexerType::kNumeric}, {"n2", indexes::IndexerType::kNumeric}, }; - vmsdk::RedisTest::SetUp(); + vmsdk::ValkeyTest::SetUp(); } - void TearDown() override { vmsdk::RedisTest::TearDown(); } + void TearDown() override { vmsdk::ValkeyTest::TearDown(); } FakeIndexInterface fake_index; }; @@ -105,10 +105,12 @@ static void DoPrefaceTestCase(FakeIndexInterface *fake_index, std::string test, DialectTestValue dialect_test, LoadsTestValue loads_test) { std::cerr << "Running test: '" << test << "'\n"; - auto argv = vmsdk::ToRedisStringVector(test); + auto argv = vmsdk::ToValkeyStringVector(test); vmsdk::ArgsIterator itr(argv.data(), argv.size()); - AggregateParameters params(fake_index); + AggregateParameters params; + params.timeout_ms = query::kTimeoutMS; + params.parse_vars_.index_interface_ = fake_index; auto parser = CreateAggregateParser(); @@ -207,7 +209,7 @@ static std::vector TestStages{ {"apply x", nullptr}, {"apply @n1", nullptr}, {"apply @n1 xx", nullptr}, - {"APPLY @n1 as ferd", "APPLY: ferd := @n1"}, + {"APPLY @n1 as freddy", "APPLY: freddy := @n1"}, }; static void DoStageTest(FakeIndexInterface *fake_index, @@ -220,10 +222,12 @@ static void DoStageTest(FakeIndexInterface *fake_index, any_bad |= TestStages[ix].stage_out_ == nullptr; } std::cout << "Doing case " << text << "\n"; - auto argv = vmsdk::ToRedisStringVector(text); + auto argv = vmsdk::ToValkeyStringVector(text); vmsdk::ArgsIterator itr(argv.data(), argv.size()); - AggregateParameters params(fake_index); + AggregateParameters params; + params.timeout_ms = 0; + params.parse_vars_.index_interface_ = fake_index; auto parser = CreateAggregateParser(); auto result = parser.Parse(params, itr); diff --git a/testing/ft_search_parser_test.cc b/testing/ft_search_parser_test.cc index 2395db5db..6680c3d2f 100644 --- a/testing/ft_search_parser_test.cc +++ b/testing/ft_search_parser_test.cc @@ -197,12 +197,30 @@ void DoVectorSearchParserTest(const FTSearchParserTestCase &test_case, std::cerr << "Executing cmd: "; for (auto &a : args) { - std::cerr << vmsdk::ToStringView(a) << " "; + std::cerr << "'" << vmsdk::ToStringView(a) << "' "; } std::cerr << "\n"; - auto search_params = ParseVectorSearchParameters(&fake_ctx, &args[0], - args.size(), schema_manager); + // Repro semantics of command startup + vmsdk::ArgsIterator itr{&args[0], int(args.size())}; + absl::StatusOr> search_params( + std::make_unique()); + (*search_params)->timeout_ms = 50000; + (*search_params)->index_schema_name = vmsdk::ToStringView(*itr.PopNext()); + (*search_params)->parse_vars.query_string = + vmsdk::ToStringView(*itr.PopNext()); + auto this_index_schema = + schema_manager.GetIndexSchema(0, (*search_params)->index_schema_name); + if (!this_index_schema.ok()) { + search_params = this_index_schema.status(); + } else { + (*search_params)->index_schema = *this_index_schema; + auto sts = (*search_params)->ParseCommand(itr); + if (!sts.ok()) { + search_params = sts; + } + } + bool expected_success = dialect_expected_success && limit_expected_success && test_case.success && !add_end_unexpected_param && timeout_expected_success; diff --git a/testing/ft_search_test.cc b/testing/ft_search_test.cc index 901f2be06..4554293f4 100644 --- a/testing/ft_search_test.cc +++ b/testing/ft_search_test.cc @@ -176,7 +176,8 @@ void SendReplyTest::DoSendReplyTest( for (const auto &neighbor : input.neighbors) { neighbors.push_back(ToIndexesNeighbor(neighbor)); } - auto parameters = std::make_unique(10000, nullptr); + auto parameters = std::make_unique(); + parameters->timeout_ms = 10000; parameters->index_schema = test_index_schema; parameters->attribute_alias = attribute_alias; parameters->score_as = vmsdk::MakeUniqueValkeyString(score_as); @@ -187,7 +188,7 @@ void SendReplyTest::DoSendReplyTest( parameters->return_attributes.push_back( ToReturnAttribute(return_attribute)); } - SendReply(&fake_ctx, neighbors, *parameters); + parameters->SendReply(&fake_ctx, neighbors); EXPECT_EQ(ParseRespReply(fake_ctx.reply_capture.GetReply()), expected_output); } diff --git a/testing/valkey_search_test.cc b/testing/valkey_search_test.cc index 4ec636926..72d700ba7 100644 --- a/testing/valkey_search_test.cc +++ b/testing/valkey_search_test.cc @@ -19,14 +19,13 @@ #include "absl/time/time.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "src/attribute_data_type.h" #include "src/coordinator/metadata_manager.h" #include "src/index_schema.h" #include "src/metrics.h" -#include "src/schema_manager.h" #include "src/utils/string_interning.h" #include "testing/common.h" #include "testing/coordinator/common.h" +#include "valkey_search_options.h" #include "vmsdk/src/memory_allocation.h" #include "vmsdk/src/module.h" #include "vmsdk/src/testing_infra/module.h" diff --git a/vmsdk/src/type_conversions.h b/vmsdk/src/type_conversions.h index 3f65c90d9..d7555038a 100644 --- a/vmsdk/src/type_conversions.h +++ b/vmsdk/src/type_conversions.h @@ -115,7 +115,7 @@ inline absl::StatusOr To(absl::string_view str) { return ToNumeric(str); } -#if defined(__clang__) +#if defined(__clang__) && !defined(RunningClangd) template <> inline absl::StatusOr To(absl::string_view str) { return ToNumeric(str);