From e34429e45511f7adbd8c2a3cc670394f87e9ee09 Mon Sep 17 00:00:00 2001 From: CahidArda Date: Thu, 26 Dec 2024 09:30:08 +0300 Subject: [PATCH] feat: add hybrid search --- src/commands/client/fetch/index.test.ts | 104 ++++++++++ src/commands/client/query/index.ts | 1 + .../client/query/query-many/index.test.ts | 61 +++++- .../client/query/query-single/index.test.ts | 178 +++++++++++++++++- .../client/query/query-single/index.ts | 3 + src/commands/client/query/types.ts | 126 ++++++++++++- src/commands/client/range/index.test.ts | 126 +++++++++++++ .../client/resumable-query/index.test.ts | 136 ++++++++++++- src/commands/client/types.ts | 3 + src/commands/client/update/index.test.ts | 94 ++++++++- src/commands/client/update/index.ts | 18 +- src/commands/client/upsert/index.test.ts | 48 ++++- src/commands/client/upsert/index.ts | 8 +- src/utils/test-utils.ts | 89 ++++++++- 14 files changed, 978 insertions(+), 17 deletions(-) diff --git a/src/commands/client/fetch/index.test.ts b/src/commands/client/fetch/index.test.ts index 8a3ce1a..08ea0e0 100644 --- a/src/commands/client/fetch/index.test.ts +++ b/src/commands/client/fetch/index.test.ts @@ -4,6 +4,8 @@ import { Index, awaitUntilIndexed, newHttpClient, + populateHybridIndex, + populateSparseIndex, randomID, range, resetIndexes, @@ -69,6 +71,14 @@ describe("FETCH with Index Client", () => { token: process.env.UPSTASH_VECTOR_REST_TOKEN!, url: process.env.UPSTASH_VECTOR_REST_URL!, }); + const sparseIndex = new Index({ + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); test("should fetch array of records by IDs succesfully", async () => { const randomizedData = Array.from({ length: 20 }) @@ -157,4 +167,98 @@ describe("FETCH with Index Client", () => { const { data: _data, ...mockDataWithoutData } = mockData; 
expect(fetchWithID).toEqual([mockDataWithoutData]); }); + + test("should fetch from sparse", async () => { + const namespace = "fetch-hybrid"; + await populateSparseIndex(sparseIndex, namespace); + + const result = await sparseIndex.fetch(["id0", "id1", "id2", "id3"], { + includeVectors: true, + includeMetadata: true, + includeData: true, + namespace, + }); + + expect(result).toEqual([ + { + id: "id0", + metadata: undefined, + data: undefined, + vector: undefined, + sparseVector: [ + [0, 1], + [0.1, 0.2], + ], + }, + { + id: "id1", + metadata: { key: "value" }, + data: undefined, + vector: undefined, + sparseVector: [ + [0, 1], + [0.2, 0.3], + ], + }, + { + id: "id2", + metadata: { key: "value" }, + data: "data", + vector: undefined, + sparseVector: [ + [0, 1], + [0.3, 0.4], + ], + }, + // @ts-expect-error checking an index that doesn't exist + undefined, + ]); + }); + + test("should fetch from hybrid", async () => { + const namespace = "fetch-hybrid"; + await populateHybridIndex(hybridIndex, namespace); + + const result = await hybridIndex.fetch(["id0", "id1", "id2", "id3"], { + includeVectors: true, + includeMetadata: true, + includeData: true, + namespace, + }); + + expect(result).toEqual([ + { + id: "id0", + metadata: undefined, + data: undefined, + vector: [0.1, 0.2], + sparseVector: [ + [0, 1], + [0.1, 0.2], + ], + }, + { + id: "id1", + metadata: { key: "value" }, + data: undefined, + vector: [0.2, 0.3], + sparseVector: [ + [0, 1], + [0.2, 0.3], + ], + }, + { + id: "id2", + metadata: { key: "value" }, + data: "data", + vector: [0.3, 0.4], + sparseVector: [ + [0, 1], + [0.3, 0.4], + ], + }, + // @ts-expect-error checking an index that doesn't exist + undefined, + ]); + }); }); diff --git a/src/commands/client/query/index.ts b/src/commands/client/query/index.ts index cc0c050..ae8b104 100644 --- a/src/commands/client/query/index.ts +++ b/src/commands/client/query/index.ts @@ -1,2 +1,3 @@ export * from "./query-many"; export * from "./query-single"; +export { 
FusionAlgorithm, QueryMode, WeightingStrategy } from "./types"; diff --git a/src/commands/client/query/query-many/index.test.ts b/src/commands/client/query/query-many/index.test.ts index d6a528c..0cdbee5 100644 --- a/src/commands/client/query/query-many/index.test.ts +++ b/src/commands/client/query/query-many/index.test.ts @@ -1,7 +1,15 @@ -import { afterAll, describe, expect, test } from "bun:test"; +import { afterAll, describe, expect, mock, test } from "bun:test"; import { UpsertCommand } from "@commands/client/upsert"; -import { Index, awaitUntilIndexed, newHttpClient, randomID, range } from "@utils/test-utils"; +import { + Index, + awaitUntilIndexed, + newHttpClient, + populateHybridIndex, + randomID, + range, +} from "@utils/test-utils"; import { QueryManyCommand } from "."; +import { FusionAlgorithm, WeightingStrategy } from "../types"; const client = newHttpClient(); @@ -77,9 +85,14 @@ describe("QUERY", () => { describe("QUERY with Index Client", () => { const index = new Index(); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); afterAll(async () => { await index.reset(); + await hybridIndex.reset({ all: true }); }); test("should query in batches successfully", async () => { const ID = randomID(); @@ -144,4 +157,48 @@ describe("QUERY with Index Client", () => { ], ]); }); + + test("should query hybrid index", async () => { + const namespace = "query-hybrid"; + const mockData = await populateHybridIndex(hybridIndex, namespace); + + const result = await index.queryMany( + [ + { + topK: 1, + vector: [0.1, 0.1], + sparseVector: [ + [3, 4], + [0.1, 0.2], + ], + fusionAlgorithm: FusionAlgorithm.RRF, + }, + { + topK: 1, + vector: [0.5, 0.1], + sparseVector: [ + [0, 1], + [0.5, 0.1], + ], + includeVectors: true, + }, + { + topK: 1, + sparseVector: [ + [2, 3], + [0.5, 0.5], + ], + weightingStrategy: WeightingStrategy.IDF, + fusionAlgorithm: FusionAlgorithm.DBSF, + 
includeMetadata: true, + }, + ], + { + namespace, + } + ); + + // @ts-expect-error will fix after testing with actual index + expect(result).toEqual("todo: fix with actual"); + }); }); diff --git a/src/commands/client/query/query-single/index.test.ts b/src/commands/client/query/query-single/index.test.ts index f73146d..f33aec3 100644 --- a/src/commands/client/query/query-single/index.test.ts +++ b/src/commands/client/query/query-single/index.test.ts @@ -1,6 +1,14 @@ import { afterAll, describe, expect, test } from "bun:test"; -import { QueryCommand, UpsertCommand } from "@commands/index"; -import { Index, awaitUntilIndexed, newHttpClient, randomID, range } from "@utils/test-utils"; +import { QueryCommand, QueryMode, UpsertCommand, WeightingStrategy } from "@commands/index"; +import { + Index, + awaitUntilIndexed, + newHttpClient, + populateHybridIndex, + populateSparseIndex, + randomID, + range, +} from "@utils/test-utils"; const client = newHttpClient(); @@ -223,6 +231,14 @@ describe("QUERY", () => { describe("QUERY with Index Client", () => { const index = new Index(); + const sparseIndex = new Index({ + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); afterAll(async () => { await index.reset(); @@ -325,4 +341,162 @@ describe("QUERY with Index Client", () => { }, ]); }); + + test("should query sparse index", async () => { + const namespace = "query-sparse"; + await populateSparseIndex(sparseIndex, namespace); + + const result = await sparseIndex.query( + { + sparseVector: [ + [0, 1, 3], + [0.1, 0.5, 0.1], + ], + topK: 5, + includeVectors: true, + includeMetadata: true, + includeData: true, + weightingStrategy: WeightingStrategy.IDF, + }, + { + namespace, + } + ); + + expect(result).toEqual([ + { + id: "id0", + score: expect.any(Number), + metadata: 
undefined, + data: undefined, + sparseVector: [ + [0, 1], + [0.1, 0.2], + ], + }, + { + id: "id1", + score: expect.any(Number), + metadata: { key: "value" }, + data: undefined, + sparseVector: [ + [0, 1], + [0.2, 0.3], + ], + }, + { + id: "id2", + score: expect.any(Number), + metadata: { key: "value" }, + data: "data", + sparseVector: [ + [0, 1], + [0.3, 0.4], + ], + }, + // @ts-expect-error checking an index that doesn't exist + undefined, + ]); + }); + + test("should query hybrid index", async () => { + const namespace = "query-hybrid"; + await populateHybridIndex(hybridIndex, namespace); + + const result = await hybridIndex.query( + { + vector: [1, 2], + sparseVector: [ + [0, 1, 3], + [0.1, 0.5, 0.1], + ], + topK: 5, + includeVectors: true, + includeMetadata: true, + includeData: true, + }, + { + namespace, + } + ); + + expect(result).toEqual([ + { + id: "id0", + score: expect.any(Number), + metadata: undefined, + data: undefined, + vector: [0.1, 0.2], + sparseVector: [ + [0, 1], + [0.1, 0.2], + ], + }, + { + id: "id1", + score: expect.any(Number), + metadata: { key: "value" }, + data: undefined, + vector: [0.2, 0.3], + sparseVector: [ + [0, 1], + [0.2, 0.3], + ], + }, + { + id: "id2", + score: expect.any(Number), + metadata: { key: "value" }, + data: "data", + vector: [0.3, 0.4], + sparseVector: [ + [0, 1], + [0.3, 0.4], + ], + }, + // @ts-expect-error checking an index that doesn't exist + undefined, + ]); + }); + + test("should query hybrid index with query mode", async () => { + const namespace = "query-hybrid-query-mode"; + await hybridIndex.upsert( + [ + { + id: "id0", + data: "hello", + }, + { + id: "id0", + data: "hello world", + }, + { + id: "id0", + data: "hello world upstash", + }, + ], + { + namespace, + } + ); + + const result = await hybridIndex.query( + { + data: "upstash", + topK: 3, + includeData: true, + // queryMode: QueryMode.SPARSE, + includeVectors: true, + }, + { + namespace, + } + ); + + expect(result).toEqual( + // @ts-expect-error 
update after actual index + "TODO: update after test" + ); + }); }); diff --git a/src/commands/client/query/query-single/index.ts b/src/commands/client/query/query-single/index.ts index 4c39078..4524a62 100644 --- a/src/commands/client/query/query-single/index.ts +++ b/src/commands/client/query/query-single/index.ts @@ -5,6 +5,7 @@ import type { QueryEndpointVariants, QueryResult, } from "../types"; +import { UpstashError } from "@error/index"; export class QueryCommand extends Command[]> { constructor(payload: QueryCommandPayload, options?: QueryCommandOptions) { @@ -12,6 +13,8 @@ export class QueryCommand extends Command[]> { if ("data" in payload) { endpoint = "query-data"; + } else if (!payload.vector && !payload.sparseVector) { + throw new UpstashError("Either data, vector or sparseVector should be provided."); } if (options?.namespace) { diff --git a/src/commands/client/query/types.ts b/src/commands/client/query/types.ts index 94bf2ca..c592e28 100644 --- a/src/commands/client/query/types.ts +++ b/src/commands/client/query/types.ts @@ -1,4 +1,4 @@ -import type { Dict, NAMESPACE } from "../types"; +import type { Dict, NAMESPACE, RawSparseVector } from "../types"; export type QueryCommandPayload = { topK: number; @@ -6,12 +6,19 @@ export type QueryCommandPayload = { includeVectors?: boolean; includeMetadata?: boolean; includeData?: boolean; -} & ({ vector: number[]; data?: never } | { data: string; vector?: never }); + weightingStrategy?: WeightingStrategy; + fusionAlgorithm?: FusionAlgorithm; + queryMode?: QueryMode; +} & ( + | { vector?: number[]; sparseVector?: RawSparseVector; data?: never } + | { data: string; vector?: never; sparseVector?: never } +); export type QueryResult = { id: number | string; score: number; vector?: number[]; + sparseVector?: RawSparseVector; metadata?: TMetadata; data?: string; }; @@ -23,3 +30,118 @@ export type QueryEndpointVariants = | `query-data` | `query/${NAMESPACE}` | `query-data/${NAMESPACE}`; + +/** + * For sparse vectors, 
what kind of weighting strategy + * should be used while querying the matching non-zero + * dimension values of the query vector with the documents. + * + * If not provided, no weighting will be used. + */ +export enum WeightingStrategy { + /** + * Inverse document frequency. + * + * It is recommended to use this weighting strategy for + * BM25 sparse embedding models. + * + * It is calculated as + * + * ln(((N - n(q) + 0.5) / (n(q) + 0.5)) + 1) where + * N: Total number of sparse vectors. + * n(q): Total number of sparse vectors having a non-zero value + * for that particular dimension. + * ln: Natural logarithm + * + * The values of N and n(q) are maintained by Upstash as the + * vectors are indexed. + */ + IDF = "IDF", +} + +/** + * Fusion algorithm to use while fusing scores + * from dense and sparse components of a hybrid index. + * + * If not provided, defaults to `RRF`. + */ +export enum FusionAlgorithm { + /** + * Reciprocal rank fusion. + * + * Each sorted score from the dense and sparse indexes is + * mapped to 1 / (rank + K), where rank is the order of the + * score in the dense or sparse scores and K is a constant + * with the value of 60. + * + * Then, scores from the dense and sparse components are + * deduplicated (i.e. if a score for the same vector is present + * in both dense and sparse scores, the mapped scores are + * added; otherwise individual mapped scores are used) + * and the final result is returned as the topK values + * of this final list. + * + * In short, this algorithm just takes the order of the scores + * into consideration. + */ + RRF = "RRF", + + /** + * Distribution-based score fusion. + * + * Each sorted score from the dense and sparse indexes is + * normalized as + * (s - (mean - 3 * stddev)) / ((mean + 3 * stddev) - (mean - 3 * stddev)) + * where s is the score, (mean - 3 * stddev) is the minimum, + * and (mean + 3 * stddev) is the maximum tail ends of the distribution. 
+ * + * Then, scores from the dense and sparse components are + * deduplicated (i.e. if a score for the same vector is present + * in both dense and sparse scores, the normalized scores are + * added; otherwise individual normalized scores are used) + * and the final result is returned as the topK values + * of this final list. + * + * In short, this algorithm takes the distribution of the scores + * into consideration as well, as opposed to the `RRF`. + */ + DBSF = "DBSF", +} + +/** + * Query mode for hybrid indexes with Upstash-hosted + * embedding models. + * + * Specifies whether to run the query in only the + * dense index, only the sparse index, or in both. + * + * If not provided, defaults to `HYBRID`. + */ +export enum QueryMode { + /** + * Runs the query in hybrid index mode, after embedding + * the raw text data into dense and sparse vectors. + * + * Query results from the dense and sparse index components + * of the hybrid index are fused before returning the result. + */ + HYBRID = "HYBRID", + + /** + * Runs the query in dense index mode, after embedding + * the raw text data into a dense vector. + * + * Only the query results from the dense index component + * of the hybrid index are returned. + */ + DENSE = "DENSE", + + /** + * Runs the query in sparse index mode, after embedding + * the raw text data into a sparse vector. + * + * Only the query results from the sparse index component + * of the hybrid index are returned. 
+ */ + SPARSE = "SPARSE", +} diff --git a/src/commands/client/range/index.test.ts b/src/commands/client/range/index.test.ts index 8fe8e5e..2b97de8 100644 --- a/src/commands/client/range/index.test.ts +++ b/src/commands/client/range/index.test.ts @@ -38,6 +38,14 @@ describe("RANGE with Index Client", () => { token: process.env.UPSTASH_VECTOR_REST_TOKEN!, url: process.env.UPSTASH_VECTOR_REST_URL!, }); + const sparseIndex = new Index({ + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); afterAll(async () => { await index.reset(); @@ -59,4 +67,122 @@ describe("RANGE with Index Client", () => { expect(res.nextCursor).toBe("5"); }); + + test("should use range for sparse", async () => { + const namespace = "range-sparse"; + + const vectors: ConstructorParameters[0] = Array.from( + { length: 20 }, + (_, i) => ({ + id: `id-${i}`, + sparseVector: [ + [Math.floor(Math.random() * 11), Math.floor(Math.random() * 11)], + [Math.random(), Math.random()], + ], + metadata: { meta: i }, + data: `data-${i}`, + }) + ); + + await sparseIndex.upsert(vectors, { namespace }); + await awaitUntilIndexed(sparseIndex); + + let res = await sparseIndex.range( + { + cursor: "", + limit: 4, + includeVectors: true, + includeMetadata: true, + includeData: true, + }, + { + namespace, + } + ); + + // Initial batch assertions + expect(res.vectors.length).toBe(4); + expect(res.nextCursor).not.toBe(""); + + for (let i = 0; i < 4; i++) { + const vector = res.vectors[i]; + expect(vector.id).toBe(`id-${i}`); + expect(vector.metadata).toEqual({ meta: i }); + expect(vector.data).toBe(`data-${i}`); + expect(vector.sparseVector).not.toBeNull(); + } + + // Paginate through remaining results + while (res.nextCursor !== "") { + res = await sparseIndex.range( + { + cursor: res.nextCursor, + limit: 8, + 
includeVectors: true, + }, + { namespace } + ); + expect(res.vectors.length).toBe(8); + } + }); + + test("should use range for hybrid", async () => { + const namespace = "range-hybrid"; + + const vectors: ConstructorParameters[0] = Array.from( + { length: 20 }, + (_, i) => ({ + id: `id-${i}`, + vector: [Math.random(), Math.random()], + sparseVector: [ + [Math.floor(Math.random() * 11), Math.floor(Math.random() * 11)], + [Math.random(), Math.random()], + ], + metadata: { meta: i }, + data: `data-${i}`, + }) + ); + + await hybridIndex.upsert(vectors, { namespace }); + await awaitUntilIndexed(hybridIndex); + + let res = await hybridIndex.range( + { + cursor: "", + limit: 4, + includeVectors: true, + includeMetadata: true, + includeData: true, + }, + { + namespace, + } + ); + + // Initial batch assertions + expect(res.vectors.length).toBe(4); + expect(res.nextCursor).not.toBe(""); + + for (let i = 0; i < 4; i++) { + const vector = res.vectors[i]; + expect(vector.id).toBe(`id-${i}`); + expect(vector.metadata).toEqual({ meta: i }); + expect(vector.data).toBe(`data-${i}`); + expect(vector.sparseVector).not.toBeNull(); + expect(vector.vector).not.toBeNull(); + } + + // Paginate through remaining results + while (res.nextCursor !== "") { + res = await hybridIndex.range( + { + cursor: res.nextCursor, + limit: 8, + includeVectors: true, + }, + { namespace } + ); + expect(res.vectors.length).toBe(8); + } + }); }); diff --git a/src/commands/client/resumable-query/index.test.ts b/src/commands/client/resumable-query/index.test.ts index 768e8e4..8d30233 100644 --- a/src/commands/client/resumable-query/index.test.ts +++ b/src/commands/client/resumable-query/index.test.ts @@ -1,10 +1,25 @@ import { afterAll, describe, expect, test } from "bun:test"; -import { Index, awaitUntilIndexed, range } from "@utils/test-utils"; +import { + Index, + awaitUntilIndexed, + populateHybridIndex, + populateSparseIndex, + range, +} from "@utils/test-utils"; import { sleep } from "bun"; +import { 
FusionAlgorithm, WeightingStrategy } from "../query/types"; describe("RESUMABLE QUERY", () => { const index = new Index(); + const sparseIndex = new Index({ + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); afterAll(async () => { await index.reset(); }); @@ -110,4 +125,123 @@ describe("RESUMABLE QUERY", () => { }, { timeout: 10_000 } ); + + test("should use resumable query for sparse index", async () => { + // Mock hybrid index object + const namespace = "resumable-sparse"; + await populateSparseIndex(sparseIndex, namespace); + // Assertion logic + const { result, fetchNext, stop } = await sparseIndex.resumableQuery( + { + sparseVector: [[0], [0.1]], + topK: 2, + includeVectors: true, + includeMetadata: true, + includeData: true, + weightingStrategy: WeightingStrategy.IDF, + fusionAlgorithm: FusionAlgorithm.DBSF, + maxIdle: 3600, + }, + { + namespace, + } + ); + + try { + expect(result.length).toBe(2); + + // Validate first result + expect(result[0].id).toBe("id0"); + expect(result[0].vector).toBeUndefined(); + expect(result[0].sparseVector).toEqual([ + [0, 1], + [0.3, 0.1], + ]); + + // Validate second result + expect(result[1].id).toBe("id1"); + expect(result[1].metadata).toEqual({ key: "value" }); + expect(result[1].vector).toBeUndefined(); + expect(result[1].sparseVector).toEqual([ + [0, 2], + [0.2, 0.1], + ]); + + // Fetch next result + const nextResult = await fetchNext(1); + expect(nextResult.length).toBe(1); + + // Validate next result + expect(nextResult[0].id).toBe("id2"); + expect(nextResult[0].metadata).toEqual({ key: "value" }); + expect(nextResult[0].data).toBe("data"); + expect(nextResult[0].vector).toBeUndefined(); + expect(nextResult[0].sparseVector).toEqual([ + [0, 3], + [0.1, 0.1], + ]); + } finally { + await stop(); + } + }); + + 
test("should use resumable query for hybrid index", async () => { + // Mock hybrid index object + const namespace = "resumable-hybrid"; + await populateHybridIndex(hybridIndex, namespace); + // Assertion logic + const { result, fetchNext, stop } = await hybridIndex.resumableQuery( + { + vector: [0.1, 0.1], + sparseVector: [[0], [0.1]], + topK: 2, + includeVectors: true, + includeMetadata: true, + includeData: true, + weightingStrategy: WeightingStrategy.IDF, + fusionAlgorithm: FusionAlgorithm.DBSF, + maxIdle: 3600, + }, + { + namespace, + } + ); + + try { + expect(result.length).toBe(2); + + // Validate first result + expect(result[0].id).toBe("id0"); + expect(result[0].vector).toEqual([0.9, 0.9]); + expect(result[0].sparseVector).toEqual([ + [0, 1], + [0.3, 0.1], + ]); + + // Validate second result + expect(result[1].id).toBe("id1"); + expect(result[1].metadata).toEqual({ key: "value" }); + expect(result[1].vector).toEqual([0.8, 0.9]); + expect(result[1].sparseVector).toEqual([ + [0, 2], + [0.2, 0.1], + ]); + + // Fetch next result + const nextResult = await fetchNext(1); + expect(nextResult.length).toBe(1); + + // Validate next result + expect(nextResult[0].id).toBe("id2"); + expect(nextResult[0].metadata).toEqual({ key: "value" }); + expect(nextResult[0].data).toBe("data"); + expect(nextResult[0].vector).toEqual([0.7, 0.9]); + expect(nextResult[0].sparseVector).toEqual([ + [0, 3], + [0.1, 0.1], + ]); + } finally { + await stop(); + } + }); }); diff --git a/src/commands/client/types.ts b/src/commands/client/types.ts index 7b2659a..7f574d7 100644 --- a/src/commands/client/types.ts +++ b/src/commands/client/types.ts @@ -1,6 +1,7 @@ export type Vector = { id: string; vector?: number[]; + sparseVector?: RawSparseVector; metadata?: TMetadata; data?: string; }; @@ -8,3 +9,5 @@ export type Vector = { export type NAMESPACE = string; export type Dict = Record; + +export type RawSparseVector = [index: number[], factors: number[]]; diff --git 
a/src/commands/client/update/index.test.ts b/src/commands/client/update/index.test.ts index 7e989fb..9687639 100644 --- a/src/commands/client/update/index.test.ts +++ b/src/commands/client/update/index.test.ts @@ -1,6 +1,14 @@ import { afterAll, describe, expect, test } from "bun:test"; import { FetchCommand, UpdateCommand, UpsertCommand } from "@commands/index"; -import { Index, awaitUntilIndexed, newHttpClient, range, resetIndexes } from "@utils/test-utils"; +import { + Index, + awaitUntilIndexed, + newHttpClient, + populateHybridIndex, + populateSparseIndex, + range, + resetIndexes, +} from "@utils/test-utils"; const client = newHttpClient(); @@ -95,8 +103,18 @@ describe("UPDATE with Index Client", () => { token: process.env.UPSTASH_VECTOR_REST_TOKEN!, url: process.env.UPSTASH_VECTOR_REST_URL!, }); + + const sparseIndex = new Index({ + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); + afterAll(async () => { - await index.reset(); + await resetIndexes(); }); test("should update vector metadata", async () => { @@ -139,4 +157,76 @@ describe("UPDATE with Index Client", () => { expect(fetchData[0]?.metadata?.upstashRocks).toBe("test-update"); }); + + test("should update sparse index", async () => { + const namespace = "update-sparse"; + await populateSparseIndex(sparseIndex, namespace); + + const { updated } = await sparseIndex.update( + { + id: "id1", + sparseVector: [ + [6, 7], + [0.5, 0.6], + ], + }, + { + namespace, + } + ); + + expect(updated).toBe(1); + + const res = await sparseIndex.fetch(["id1"], { + includeVectors: true, + namespace, + }); + + // Assert fetch results + expect(res.length).toBe(1); + expect(res[0]).not.toBeNull(); + const vector = res[0]!; + expect(vector.id).toBe("id1"); + expect(vector.sparseVector).toEqual([ + [6, 7], + [0.5, 0.6], 
+ ]); + }); + + test("should update hybrid index", async () => { + const namespace = "update-hybrid"; + await populateHybridIndex(hybridIndex, namespace); + + const { updated } = await hybridIndex.update( + { + id: "id1", + vector: [0.5, 0.6], + sparseVector: [ + [6, 7], + [0.5, 0.6], + ], + }, + { + namespace, + } + ); + + expect(updated).toBe(1); + + const res = await sparseIndex.fetch(["id1"], { + includeVectors: true, + namespace, + }); + + // Assert fetch results + expect(res.length).toBe(1); + expect(res[0]).not.toBeNull(); + const vector = res[0]!; + expect(vector.id).toBe("id1"); + expect(vector.vector).toEqual([0.5, 0.6]); + expect(vector.sparseVector).toEqual([ + [6, 7], + [0.5, 0.6], + ]); + }); }); diff --git a/src/commands/client/update/index.ts b/src/commands/client/update/index.ts index cba0d24..1f35a5b 100644 --- a/src/commands/client/update/index.ts +++ b/src/commands/client/update/index.ts @@ -1,5 +1,6 @@ -import type { NAMESPACE } from "@commands/client/types"; +import type { NAMESPACE, RawSparseVector } from "@commands/client/types"; import { Command } from "@commands/command"; +import { UpstashError } from "@error/index"; type NoInfer = T extends infer U ? 
U : never; @@ -11,7 +12,8 @@ type MetadataUpdatePayload = { type VectorUpdatePayload = { id: string | number; - vector: number[]; + vector?: number[]; + sparseVector?: RawSparseVector; }; type DataUpdatePayload = { @@ -33,6 +35,18 @@ export class UpdateCommand extends Command { constructor(payload: Payload, opts?: UpdateCommandOptions) { let endpoint: UpdateEndpointVariants = "update"; + if ( + !("metadata" in payload) && + !("vector" in payload) && + !("sparseVector" in payload) && + !("data" in payload) + ) { + throw new UpstashError( + `Error while updating vector with id ${payload.id}. ` + + `At least one of 'metadata', 'vector', 'sparseVector' or 'data' should be provided.` + ); + } + if (opts?.namespace) { endpoint = `${endpoint}/${opts.namespace}`; } diff --git a/src/commands/client/upsert/index.test.ts b/src/commands/client/upsert/index.test.ts index 91c6b6b..c3e32dc 100644 --- a/src/commands/client/upsert/index.test.ts +++ b/src/commands/client/upsert/index.test.ts @@ -4,6 +4,8 @@ import { Index, awaitUntilIndexed, newHttpClient, + populateHybridIndex, + populateSparseIndex, randomID, range, resetIndexes, @@ -22,7 +24,6 @@ describe("UPSERT", () => { test("should return an error when vector is missing", () => { // eslint-disable-next-line unicorn/consistent-function-scoping const throwable = async () => { - //@ts-expect-error Missing vector field in upsert command. await new UpsertCommand({ id: 1 }).exec(client); }; expect(throwable).toThrow(); @@ -102,7 +103,6 @@ describe("UPSERT with Index Client", () => { test("should return an error when vector is missing", () => { // eslint-disable-next-line unicorn/consistent-function-scoping const throwable = async () => { - //@ts-expect-error Missing vector field in upsert command. 
await new UpsertCommand({ id: 1 }).exec(client); }; expect(throwable).toThrow(); @@ -173,6 +173,16 @@ describe("UPSERT with Index Client", () => { describe("Upsert with new data field", () => { afterAll(async () => await resetIndexes()); + + const sparseIndex = new Index({ + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }); + const hybridIndex = new Index({ + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }); + test("should add data to data field - /upsert-data", async () => { const id = randomID(); const data = "testing data"; @@ -223,4 +233,38 @@ describe("Upsert with new data field", () => { expect(result.map((r) => r.data)).toEqual([data]); }); + + test("should upsert to sparse", async () => { + const namespace = "upsert-sparse"; + + // populate will upsert vectors + const mockData = await populateSparseIndex(sparseIndex, namespace); + + const result = await sparseIndex.fetch(mockData.map((vector) => vector.id) as string[], { + includeVectors: true, + namespace, + }); + + expect(result).toBe( + // @ts-expect-error will be updated after running with actual index + "TODO: update after using idnex" + ); + }); + + test("should upsert to hybrid", async () => { + const namespace = "upsert-hybrid"; + + // populate will upsert vectors + const mockData = await populateHybridIndex(hybridIndex, namespace); + + const result = await hybridIndex.fetch(mockData.map((vector) => vector.id) as string[], { + includeVectors: true, + namespace, + }); + + expect(result).toBe( + // @ts-expect-error will be updated after running with actual index + "TODO: update after using idnex" + ); + }); }); diff --git a/src/commands/client/upsert/index.ts b/src/commands/client/upsert/index.ts index 714619a..44386d1 100644 --- a/src/commands/client/upsert/index.ts +++ b/src/commands/client/upsert/index.ts @@ -1,11 +1,12 @@ -import type { NAMESPACE } from 
"@commands/client/types"; +import type { NAMESPACE, RawSparseVector } from "@commands/client/types"; import { Command } from "@commands/command"; type NoInfer = T extends infer U ? U : never; type VectorPayload = { id: number | string; - vector: number[]; + vector?: number[]; + sparseVector?: RawSparseVector; metadata?: NoInfer; }; @@ -50,5 +51,6 @@ export class UpsertCommand extends Command { const isVectorPayload = ( payload: VectorPayload | DataPayload ): payload is VectorPayload => { - return "vector" in payload; + // TODO: fix field name + return "vector" in payload || "sparseVector" in payload; }; diff --git a/src/utils/test-utils.ts b/src/utils/test-utils.ts index 702739e..9f3b465 100644 --- a/src/utils/test-utils.ts +++ b/src/utils/test-utils.ts @@ -3,6 +3,7 @@ import { InfoCommand } from "../commands/client/info"; import { ResetCommand } from "../commands/client/reset"; import { HttpClient, type RetryConfig } from "../http"; import type { Index } from "../vector"; +import { UpsertCommand } from "@commands/client"; export * from "../platforms/nodejs"; export type NonArrayType = T extends Array ? 
U : T; @@ -53,7 +54,22 @@ export function randomID(): string { export const randomFloat = () => Number.parseFloat((Math.random() * 10).toFixed(1)); -export const resetIndexes = async () => await new ResetCommand().exec(newHttpClient()); +export const resetIndexes = async () => + await Promise.all([ + new ResetCommand({ all: true }).exec(newHttpClient()), + new ResetCommand({ all: true }).exec( + newHttpClient(undefined, { + token: process.env.SPARSE_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.SPARSE_UPSTASH_VECTOR_REST_URL!, + }) + ), + new ResetCommand({ all: true }).exec( + newHttpClient(undefined, { + token: process.env.HYBRID_UPSTASH_VECTOR_REST_TOKEN!, + url: process.env.HYBRID_UPSTASH_VECTOR_REST_URL!, + }) + ), + ]); export const range = (start: number, end: number, step = 1) => { const result = []; @@ -89,3 +105,74 @@ export const awaitUntilIndexed = async (client: HttpClient | Index, timeoutMilli throw new Error(`Indexing is not completed in ${timeoutMillis} ms.`); }; + +export const populateSparseIndex = async (index: Index, namespace: string) => { + const mockData: ConstructorParameters[0] = [ + { + id: "id0", + sparseVector: [ + [0, 1], + [0.1, 0.2], + ], + }, + { + id: "id1", + sparseVector: [ + [1, 2], + [0.2, 0.3], + ], + metadata: { key: "value" }, + }, + ]; + mockData.push({ + id: "id2", + sparseVector: [ + [2, 3], + [0.3, 0.4], + ], + metadata: { key: "value" }, + // @ts-expect-error data field isn't allowed because this is + // a vector payload. 
but we allow it for the test purposes + data: "data", + }); + await index.upsert(mockData, { namespace }); + await awaitUntilIndexed(index); + return mockData; +}; + +export const populateHybridIndex = async (index: Index, namespace: string) => { + const mockData: ConstructorParameters[0] = [ + { + id: "id0", + vector: [0.1, 0.2], + sparseVector: [ + [0, 1], + [0.1, 0.2], + ], + }, + { + id: "id1", + vector: [0.2, 0.3], + sparseVector: [ + [1, 2], + [0.2, 0.3], + ], + metadata: { key: "value" }, + }, + ]; + mockData.push({ + id: "id2", + vector: [0.3, 0.4], + sparseVector: [ + [2, 3], + [0.3, 0.4], + ], + metadata: { key: "value" }, + // @ts-expect-error data field isn't allowed because this is + // a vector payload. but we allow it for the test purposes + data: "data", + }); + await index.upsert(mockData, { namespace }); + await awaitUntilIndexed(index); + return mockData; +};