Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,9 @@ char *getRedisConfigValue(RedisModuleCtx *ctx, const char* confName);
#define BM25STD_TANH_FACTOR_MIN 1
#define DEFAULT_BG_OOM_PAUSE_TIME_BEFOR_RETRY 5
#define DEFAULT_INDEXER_YIELD_EVERY_OPS 1000
#define DEFAULT_SHARD_WINDOW_RATIO 1.0
#define MIN_SHARD_WINDOW_RATIO 0.0 // Exclusive minimum (must be > 0.0)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Documentation]

This comment is slightly confusing. While technically correct that it's an exclusive minimum, it could be clearer. Consider rephrasing to directly state the requirement for ratio.

Copy Context

#define MAX_SHARD_WINDOW_RATIO 1.0

// default configuration
#define RS_DEFAULT_CONFIG { \
Expand Down
27 changes: 24 additions & 3 deletions src/coord/dist_aggregate.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@
#include "util/timeout.h"
#include "resp3.h"
#include "coord/config.h"
#include "config.h"
#include "dist_profile.h"
#include "shard_window_ratio.h"
#include "util/misc.h"
#include "aggregate/aggregate_debug.h"
#include "info/info_redis/threads/current_thread.h"
Expand Down Expand Up @@ -529,7 +531,7 @@ static RPNet *RPNet_New(const MRCommand *cmd) {
}

static void buildMRCommand(RedisModuleString **argv, int argc, int profileArgs,
AREQDIST_UpstreamInfo *us, MRCommand *xcmd, IndexSpec *sp) {
AREQDIST_UpstreamInfo *us, MRCommand *xcmd, IndexSpec *sp, specialCaseCtx *knnCtx) {
// We need to prepend the array with the command, index, and query that
// we want to use.
const char **tmparr = array_new(const char *, us->nserialized);
Expand Down Expand Up @@ -608,6 +610,22 @@ static void buildMRCommand(RedisModuleString **argv, int argc, int profileArgs,
}
}

// Handle KNN with shard ratio optimization for both multi-shard and standalone
if (knnCtx) {
KNNVectorQuery *knn_query = &knnCtx->knn.queryNode->vn.vq->knn;
double ratio = knn_query->shardWindowRatio;

if (ratio < MAX_SHARD_WINDOW_RATIO) {
// Apply optimization only if ratio is valid and < 1.0 (ratio = 1.0 means no optimization)
// Calculate effective K based on deployment mode
size_t numShards = GetNumShards_UnSafe();
size_t effectiveK = calculateEffectiveK(knn_query->k, ratio, numShards);

// Modify the command to replace KNN k (shards will ignore $SHARD_K_RATIO)
modifyKNNCommand(xcmd, 2 + profileArgs, effectiveK, knnCtx->knn.queryNode->vn.vq);
}
}
Comment on lines +613 to +627

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BestPractice]

This block of logic for handling the shard window ratio is very similar to the logic in src/module.c inside the prepareCommand function (lines 3298-3318). To avoid code duplication and improve maintainability, consider refactoring this into a shared helper function.

The implementation in module.c is slightly more optimized as it includes a check if (knn_query->k == effectiveK) break; to avoid unnecessary command modification. This check should be included in the shared function.

Copy Context


// check for timeout argument and append it to the command.
// If TIMEOUT exists, it was already validated at AREQ_Compile.
int timeout_index = RMUtil_ArgIndex("TIMEOUT", argv + 3 + profileArgs, argc - 4 - profileArgs);
Expand Down Expand Up @@ -713,11 +731,14 @@ static int prepareForExecution(AREQ *r, RedisModuleCtx *ctx, RedisModuleString *
r->profile = printAggProfile;

unsigned int dialect = r->reqConfig.dialectVersion;
specialCaseCtx *knnCtx = NULL;

if(dialect >= 2) {
// Check if we have KNN in the query string, and if so, parse the query string to see if it is
// a KNN section in the query. IN that case, we treat this as a SORTBY+LIMIT step.
if(strcasestr(r->query, "KNN")) {
specialCaseCtx *knnCtx = prepareOptionalTopKCase(r->query, argv, argc, dialect, status);
// For distributed aggregation, command type detection is automatic
knnCtx = prepareOptionalTopKCase(r->query, argv, argc, dialect, status);
*knnCtx_ptr = knnCtx;
if (QueryError_HasError(status)) {
return REDISMODULE_ERR;
Expand All @@ -739,7 +760,7 @@ static int prepareForExecution(AREQ *r, RedisModuleCtx *ctx, RedisModuleString *

// Construct the command string
MRCommand xcmd;
buildMRCommand(argv , argc, profileArgs, &us, &xcmd, sp);
buildMRCommand(argv , argc, profileArgs, &us, &xcmd, sp, knnCtx);
xcmd.protocol = is_resp3(ctx) ? 3 : 2;
xcmd.forCursor = r->reqflags & QEXEC_F_IS_CURSOR;
xcmd.forProfiling = IsProfile(r);
Expand Down
40 changes: 40 additions & 0 deletions src/coord/rmr/command.c
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,46 @@ void MRCommand_ReplaceArg(MRCommand *cmd, int index, const char *newArg, size_t
MRCommand_ReplaceArgNoDup(cmd, index, news, len);
}

void MRCommand_ReplaceArgSubstring(MRCommand *cmd, int index, size_t pos, size_t oldSubStringLen, const char *newStr, size_t newLen) {
RS_LOG_ASSERT_FMT(index >= 0 && index < cmd->num, "Invalid index %d. Command has %d arguments", index, cmd->num);

char *oldArg = cmd->strs[index];
// Get full argument length
size_t oldArgLen = cmd->lens[index];

// Validate position and length
RS_LOG_ASSERT_FMT(pos + oldSubStringLen <= oldArgLen, "Invalid position %zu. Argument length is %zu", pos, oldArgLen);

// Calculate new total length
size_t newArgLen = oldArgLen - oldSubStringLen + newLen;

// OPTIMIZATION: For query string literals, pad with spaces instead of moving memory
if (newLen <= oldSubStringLen) {
// Copy new string
memcpy(oldArg + pos, newStr, newLen);

// Pad remaining space with spaces (no memmove needed)
memset(oldArg + pos + newLen, ' ', oldSubStringLen - newLen);

// No length change needed - argument stays same size
return;
}

// Fallback: Allocate new string for longer replacements
char *newArg = rm_malloc(newArgLen + 1);

// Copy parts: [before] + [new] + [after]
memcpy(newArg, oldArg, pos); // Copy before
memcpy(newArg + pos, newStr, newLen); // Copy new substring
memcpy(newArg + pos + newLen, oldArg + pos + oldSubStringLen, // Copy after
oldArgLen - pos - oldSubStringLen);

newArg[newArgLen] = '\0';

// Replace the argument
MRCommand_ReplaceArgNoDup(cmd, index, newArg, newArgLen);
}

// Should only be relevant for _FT.ADD, _FT.GET, _FT.DEL,
// and _FT.SUG* commands
int MRCommand_GetShardingKey(const MRCommand *cmd) {
Expand Down
13 changes: 13 additions & 0 deletions src/coord/rmr/command.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,19 @@ void MRCommand_SetPrefix(MRCommand *cmd, const char *newPrefix);
void MRCommand_ReplaceArg(MRCommand *cmd, int index, const char *newArg, size_t len);
void MRCommand_ReplaceArgNoDup(MRCommand *cmd, int index, const char *newArg, size_t len);

/** Replace a substring within an argument at a specific position
* OPTIMIZATION: Avoids reallocation when new string is same/shorter length.
* Instead, pads with spaces.
*
* @param cmd - Command structure containing the arguments
* @param index - Index of the argument to modify
* @param pos - Starting position within the argument string
* @param oldSubStringLen - Length of the substring to replace
* @param newStr - New string to insert
* @param newLen - Length of the new string
*/
void MRCommand_ReplaceArgSubstring(MRCommand *cmd, int index, size_t pos, size_t oldSubStringLen, const char *newStr, size_t newLen);

void MRCommand_WriteTaggedKey(MRCommand *cmd, int index, const char *newarg, const char *part,
size_t n);

Expand Down
41 changes: 41 additions & 0 deletions src/coord/special_case_ctx.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Copyright (c) 2006-Present, Redis Ltd.
* All rights reserved.
*
* Licensed under your choice of the Redis Source Available License 2.0
* (RSALv2); or (b) the Server Side Public License v1 (SSPLv1); or (c) the
* GNU Affero General Public License v3 (AGPLv3).
*/
#pragma once

#include "util/heap.h"
#include "query_node.h"

typedef enum {
SPECIAL_CASE_NONE,
SPECIAL_CASE_KNN,
SPECIAL_CASE_SORTBY
} searchRequestSpecialCase;

typedef struct {
size_t k; // K value TODO: consider remove from here, its in querynode
const char* fieldName; // Field name
bool shouldSort; // Should run presort before the coordinator sort
size_t offset; // Reply offset
heap_t *pq; // Priority queue
QueryNode* queryNode; // Query node
} knnContext;

typedef struct {
const char* sortKey; // SortKey name;
bool asc; // Sort order ASC/DESC
size_t offset; // SortKey reply offset
} sortbyContext;

typedef struct {
union {
knnContext knn;
sortbyContext sortby;
};
searchRequestSpecialCase specialCaseType;
} specialCaseCtx;
27 changes: 26 additions & 1 deletion src/module.c
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
#include "reply.h"
#include "resp3.h"
#include "coord/rmr/rmr.h"
#include "shard_window_ratio.h"

#include "hiredis/async.h"
#include "coord/rmr/reply.h"
#include "coord/rmr/redis_cluster.h"
Expand Down Expand Up @@ -2104,7 +2106,7 @@ searchResult *newResult_resp3(searchResult *cached, MRReply *results, int j, sea
}

MRReply *result_id = MRReply_MapElement(result_j, "id");
if (!result_id || !MRReply_Type(result_id) == MR_REPLY_STRING) {
if (!result_id || MRReply_Type(result_id) != MR_REPLY_STRING) {
// We crash in development env, and return NULL (such that an error is raised)
// in production.
RS_LOG_ASSERT_FMT(false, "Expected id %d to exist, and be a string", j);
Expand Down Expand Up @@ -2892,6 +2894,7 @@ static int searchResultReducer(struct MRCtx *mc, int count, MRReply **replies) {
if (rCtx.reduceSpecialCaseCtxKnn &&
rCtx.reduceSpecialCaseCtxKnn->knn.pq) {
heap_destroy(rCtx.reduceSpecialCaseCtxKnn->knn.pq);
rCtx.reduceSpecialCaseCtxKnn->knn.pq = NULL;
}

RedisModule_BlockedClientMeasureTimeEnd(bc);
Expand Down Expand Up @@ -3292,6 +3295,28 @@ static int prepareCommand(MRCommand *cmd, searchRequestCtx *req, RedisModuleBloc

cmd->protocol = protocol;

// Handle KNN with shard ratio optimization for both multi-shard and standalone
if (req->specialCases) {
for (size_t i = 0; i < array_len(req->specialCases); ++i) {
if (req->specialCases[i]->specialCaseType == SPECIAL_CASE_KNN) {
specialCaseCtx* knnCtx = req->specialCases[i];
KNNVectorQuery *knn_query = &knnCtx->knn.queryNode->vn.vq->knn;
double ratio = knn_query->shardWindowRatio;

// Apply optimization only if ratio is valid and < 1.0 (ratio = 1.0 means no optimization)
if (ratio < MAX_SHARD_WINDOW_RATIO) {
// Calculate effective K based on deployment mode
size_t effectiveK = calculateEffectiveK(knn_query->k, ratio, NumShards);
// No modification needed if K values are the same
if (knn_query->k == effectiveK) break;
// Modify the command to replace KNN k (shards will ignore $SHARD_K_RATIO)
modifyKNNCommand(cmd, 2 + req->profileArgs, effectiveK, knnCtx->knn.queryNode->vn.vq);
}
break; // Only handle KNN context
}
}
}

// replace the LIMIT {offset} {limit} with LIMIT 0 {limit}, because we need all top N to merge
int limitIndex = RMUtil_ArgExists("LIMIT", argv, argc, 3);
if (limitIndex && req->limit > 0 && limitIndex < argc - 2) {
Expand Down
31 changes: 2 additions & 29 deletions src/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include <query_node.h>
#include <coord/rmr/reply.h>
#include <util/heap.h>
#include "shard_window_ratio.h"
#include "coord/special_case_ctx.h"

// Hack to support Alpine Linux 3 where __STRING is not defined
#if !defined(__GLIBC__) && !defined(__STRING)
Expand Down Expand Up @@ -77,35 +79,6 @@ do { \
return REDISMODULE_ERR; \
}

typedef enum {
SPECIAL_CASE_NONE,
SPECIAL_CASE_KNN,
SPECIAL_CASE_SORTBY
} searchRequestSpecialCase;

typedef struct {
size_t k; // K value
const char* fieldName; // Field name
bool shouldSort; // Should run presort before the coordinator sort
size_t offset; // Reply offset
heap_t *pq; // Priority queue
QueryNode* queryNode; // Query node
} knnContext;

typedef struct {
const char* sortKey; // SortKey name;
bool asc; // Sort order ASC/DESC
size_t offset; // SortKey reply offset
} sortbyContext;

typedef struct {
union {
knnContext knn;
sortbyContext sortby;
};
searchRequestSpecialCase specialCaseType;
} specialCaseCtx;

typedef struct {
char *queryString;
long long offset;
Expand Down
49 changes: 43 additions & 6 deletions src/query.c
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,19 @@ QueryNode *NewVectorNode_WithParams(struct QueryParseCtx *q, VectorQueryType typ
QueryNode_InitParams(ret, 2);
QueryNode_SetParam(q, &ret->params[0], &vq->knn.vector, &vq->knn.vecLen, vec);
QueryNode_SetParam(q, &ret->params[1], &vq->knn.k, NULL, value);
vq->knn.shardWindowRatio = DEFAULT_SHARD_WINDOW_RATIO;

// Save K position so it can be modified later in the shard command.
// NOTE: If k is given as a *parameter*:
// 1. value->pos: position of "$"
vq->knn.k_token_pos = value->pos;
// 2. value->len: length of the parameter name (e.g. $k -> len=1, $k_meow -> len=6)
// So we need to include the '$' in the token length.
if (value->type == QT_PARAM_SIZE) {
vq->knn.k_token_len = value->len + 1;
} else { // k is literal
vq->knn.k_token_len = value->len;
}
break;
case VECSIM_QT_RANGE:
QueryNode_InitParams(ret, 2);
Expand Down Expand Up @@ -2144,13 +2157,37 @@ int QueryNode_ForEach(QueryNode *q, QueryNode_ForEachCallback callback, void *ct
return retVal;
}

static int ValidateShardKRatio(const char *value, double *ratio, QueryError *status) {
if (!ParseDouble(value, ratio, 1)) {
QueryError_SetWithUserDataFmt(status, QUERY_EINVAL,
"Invalid shard k ratio value", " '%s'", value);
return 0;
}

if (*ratio <= MIN_SHARD_WINDOW_RATIO || *ratio > MAX_SHARD_WINDOW_RATIO) {
QueryError_SetWithoutUserDataFmt(status, QUERY_EINVAL,
"Invalid shard k ratio value: Shard k ratio must be greater than %g and at most %g (got %g)",
MIN_SHARD_WINDOW_RATIO, MAX_SHARD_WINDOW_RATIO, *ratio);
return 0;
}

return 1;
}

// Convert the query attribute into a raw vector param to be resolved by the vector iterator
// down the road. return 0 in case of an unrecognized parameter.
static int QueryVectorNode_ApplyAttribute(VectorQuery *vq, QueryAttribute *attr) {
if (STR_EQCASE(attr->name, attr->namelen, VECSIM_EFRUNTIME) ||
STR_EQCASE(attr->name, attr->namelen, VECSIM_EPSILON) ||
STR_EQCASE(attr->name, attr->namelen, VECSIM_HYBRID_POLICY) ||
STR_EQCASE(attr->name, attr->namelen, VECSIM_BATCH_SIZE)) {
static int QueryVectorNode_ApplyAttribute(VectorQuery *vq, QueryAttribute *attr, QueryError *status) {
if (STR_EQCASE(attr->name, attr->namelen, SHARD_K_RATIO_ATTR)) {
double ratio;
if (!ValidateShardKRatio(attr->value, &ratio, status)) {
return 0;
}
vq->knn.shardWindowRatio = ratio;
return 1;
} else if (STR_EQCASE(attr->name, attr->namelen, VECSIM_EFRUNTIME) ||
STR_EQCASE(attr->name, attr->namelen, VECSIM_EPSILON) ||
STR_EQCASE(attr->name, attr->namelen, VECSIM_HYBRID_POLICY) ||
STR_EQCASE(attr->name, attr->namelen, VECSIM_BATCH_SIZE)) {
// Move ownership on the value string, so it won't get freed when releasing the QueryAttribute.
// The name string was not copied by the parser (unlike the value) - so we copy and save it.
VecSimRawParam param = (VecSimRawParam){ .name = rm_strndup(attr->name, attr->namelen),
Expand Down Expand Up @@ -2234,7 +2271,7 @@ static int QueryNode_ApplyAttribute(QueryNode *qn, QueryAttribute *attr, QueryEr
res = 1;

} else if (qn->type == QN_VECTOR) {
res = QueryVectorNode_ApplyAttribute(qn->vn.vq, attr);
res = QueryVectorNode_ApplyAttribute(qn->vn.vq, attr, status);
}

if (!res) {
Expand Down
1 change: 1 addition & 0 deletions src/query_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ typedef struct {
#define INORDER_ATTR "inorder"
#define WEIGHT_ATTR "weight"
#define PHONETIC_ATTR "phonetic"
#define SHARD_K_RATIO_ATTR "shard_k_ratio"


/* Various modifiers and options that can apply to the entire query or any sub-query of it */
Expand Down
Loading