diff --git a/README.md b/README.md index f90be20a312e..ad93ed3ce3ec 120000 --- a/README.md +++ b/README.md @@ -1 +1 @@ -packages/ai/README.md \ No newline at end of file +packages/ai/README.md diff --git a/packages/ai/core/rerank/__snapshots__/rerank.test.ts.snap b/packages/ai/core/rerank/__snapshots__/rerank.test.ts.snap new file mode 100644 index 000000000000..7ee762d31de5 --- /dev/null +++ b/packages/ai/core/rerank/__snapshots__/rerank.test.ts.snap @@ -0,0 +1,134 @@ +// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html + +exports[`telemetry > should not record telemetry inputs / outputs when disabled 1`] = ` +[ + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank", + "ai.usage.tokens": 10, + "operation.name": "ai.rerank", + }, + "events": [], + "name": "ai.rerank", + }, + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank.doRerank", + "ai.usage.tokens": 10, + "operation.name": "ai.rerank.doRerank", + }, + "events": [], + "name": "ai.rerank.doRerank", + }, +] +`; + +exports[`telemetry > should record telemetry data when enabled (multiple calls path) 1`] = ` +[ + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank", + "ai.rerankedDocuments": "[]", + "ai.rerankedIndices": "[1,0]", + "ai.telemetry.functionId": "test-function-id", + "ai.telemetry.metadata.test1": "value1", + "ai.telemetry.metadata.test2": false, + "ai.usage.tokens": 20, + "ai.values": "["sunny day at the beach","rainy day in the city"]", + "operation.name": "ai.rerank test-function-id", + "resource.name": "test-function-id", + }, + "events": [], + "name": "ai.rerank", + }, + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank.doRerank", + "ai.rerankedDocuments": "[]", + "ai.rerankedIndices": "[1]", + 
"ai.telemetry.functionId": "test-function-id", + "ai.telemetry.metadata.test1": "value1", + "ai.telemetry.metadata.test2": false, + "ai.usage.tokens": 10, + "ai.values": [ + "["sunny day at the beach"]", + ], + "operation.name": "ai.rerank.doRerank test-function-id", + "resource.name": "test-function-id", + }, + "events": [], + "name": "ai.rerank.doRerank", + }, + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank.doRerank", + "ai.rerankedDocuments": "[]", + "ai.rerankedIndices": "[0]", + "ai.telemetry.functionId": "test-function-id", + "ai.telemetry.metadata.test1": "value1", + "ai.telemetry.metadata.test2": false, + "ai.usage.tokens": 10, + "ai.values": [ + "["rainy day in the city"]", + ], + "operation.name": "ai.rerank.doRerank test-function-id", + "resource.name": "test-function-id", + }, + "events": [], + "name": "ai.rerank.doRerank", + }, +] +`; + +exports[`telemetry > should record telemetry data when enabled (single call path) 1`] = ` +[ + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank", + "ai.rerankedDocuments": "["rainy day in the city","sunny day at the beach"]", + "ai.rerankedIndices": "[1,0]", + "ai.telemetry.functionId": "test-function-id", + "ai.telemetry.metadata.test1": "value1", + "ai.telemetry.metadata.test2": false, + "ai.usage.tokens": 10, + "ai.values": "["sunny day at the beach","rainy day in the city"]", + "operation.name": "ai.rerank test-function-id", + "resource.name": "test-function-id", + }, + "events": [], + "name": "ai.rerank", + }, + { + "attributes": { + "ai.model.id": "mock-model-id", + "ai.model.provider": "mock-provider", + "ai.operationId": "ai.rerank.doRerank", + "ai.rerankedDocuments": "["rainy day in the city","sunny day at the beach"]", + "ai.rerankedIndices": "[1,0]", + "ai.telemetry.functionId": "test-function-id", + "ai.telemetry.metadata.test1": "value1", + 
"ai.telemetry.metadata.test2": false, + "ai.usage.tokens": 10, + "ai.values": [ + "["sunny day at the beach","rainy day in the city"]", + ], + "operation.name": "ai.rerank.doRerank test-function-id", + "resource.name": "test-function-id", + }, + "events": [], + "name": "ai.rerank.doRerank", + }, +] +`; diff --git a/packages/ai/core/rerank/index.ts b/packages/ai/core/rerank/index.ts new file mode 100644 index 000000000000..67d7d1afe50e --- /dev/null +++ b/packages/ai/core/rerank/index.ts @@ -0,0 +1,2 @@ +export * from './rerank'; +export * from './rerank-result'; diff --git a/packages/ai/core/rerank/rerank-result.ts b/packages/ai/core/rerank/rerank-result.ts new file mode 100644 index 000000000000..b3cee091ce74 --- /dev/null +++ b/packages/ai/core/rerank/rerank-result.ts @@ -0,0 +1,40 @@ +import { RerankedDocumentIndex } from '../types'; +import { RerankingModelUsage } from '../types/usage'; + +/** +The result of an `rerank` call. +It contains the documents, the reranked indices, and additional information. + */ +export interface RerankResult { + /** + The documents that were reranked. + */ + readonly documents: Array; + + /** + The reranked indices. + */ + readonly rerankedIndices: Array; + + /** + * The reranked documents. + * Only available if `returnDocuments` was set to `true`. + * The order of the documents is the same as the order of the indices. + */ + readonly rerankedDocuments?: Array; + + /** + The reranking token usage. + */ + readonly usage: RerankingModelUsage; + + /** + Optional raw response data. + */ + readonly rawResponse?: { + /** + Response headers. 
+ */ + headers?: Record; + }; +} diff --git a/packages/ai/core/rerank/rerank.test.ts b/packages/ai/core/rerank/rerank.test.ts new file mode 100644 index 000000000000..3242e873b3f5 --- /dev/null +++ b/packages/ai/core/rerank/rerank.test.ts @@ -0,0 +1,263 @@ +import assert from 'node:assert'; +import { MockTracer } from '../test/mock-tracer'; +import { rerank } from './rerank'; +import { + mockRerank, + MockRerankingModelV1, +} from '../test/mock-reranking-model-v1'; + +const dummyDocumentsIndices = [1, 0]; +const dummyDocuments = ['sunny day at the beach', 'rainy day in the city']; +const dummyRerankedDocuments = [ + 'rainy day in the city', + 'sunny day at the beach', +]; +const query = 'rainy day'; +const topK = 2; + +describe('result.reranking', () => { + it('should reranking documents', async () => { + const result = await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 5, + doRerank: mockRerank( + dummyDocuments, + dummyDocumentsIndices, + dummyRerankedDocuments, + ), + }), + values: dummyDocuments, + query, + topK, + }); + + assert.deepStrictEqual(result.rerankedIndices, dummyDocumentsIndices); + }); + + it('should reranking documents when several calls are required', async () => { + let callCount = 0; + const result = await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 1, + doRerank: async () => { + switch (callCount++) { + case 0: + return { + rerankedIndices: dummyDocumentsIndices.slice(0, 1), + }; + case 1: + return { + rerankedIndices: dummyDocumentsIndices.slice(1), + }; + default: + throw new Error('Unexpected call'); + } + }, + }), + values: dummyDocuments, + query, + topK, + }); + + assert.deepStrictEqual(result.rerankedIndices, dummyDocumentsIndices); + }); +}); + +describe('result.value', () => { + it('should include value in the result', async () => { + const result = await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 5, + doRerank: mockRerank( + dummyDocuments, + 
dummyDocumentsIndices, + dummyRerankedDocuments, + ), + }), + values: dummyDocuments, + query, + topK, + returnDocuments: true, + }); + + assert.deepStrictEqual(result.rerankedDocuments, dummyRerankedDocuments); + }); +}); + +describe('result.usage', () => { + it('should include usage in the result', async () => { + let callCount = 0; + const result = await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 1, + doRerank: async () => { + switch (callCount++) { + case 0: + return { + rerankedIndices: dummyDocumentsIndices.slice(0, 1), + usage: { tokens: 10 }, + }; + case 1: + return { + rerankedIndices: dummyDocumentsIndices.slice(1), + usage: { tokens: 10 }, + }; + default: + throw new Error('Unexpected call'); + } + }, + }), + values: dummyDocuments, + query, + topK, + }); + + assert.deepStrictEqual(result.usage, { tokens: 20 }); + }); +}); + +describe('options.headers', () => { + it('should set headers', async () => { + const result = await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 5, + doRerank: async ({ headers }) => { + assert.deepStrictEqual(headers, { + 'custom-request-header': 'request-header-value', + }); + + return { rerankedIndices: dummyDocumentsIndices }; + }, + }), + values: dummyDocuments, + query, + topK, + headers: { 'custom-request-header': 'request-header-value' }, + }); + + assert.deepStrictEqual(result.rerankedIndices, dummyDocumentsIndices); + }); +}); + +describe('telemetry', () => { + let tracer: MockTracer; + + beforeEach(() => { + tracer = new MockTracer(); + }); + + it('should not record any telemetry data when not explicitly enabled', async () => { + await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 5, + doRerank: mockRerank( + dummyDocuments, + dummyDocumentsIndices, + dummyRerankedDocuments, + ), + }), + values: dummyDocuments, + query, + topK, + }); + assert.deepStrictEqual(tracer.jsonSpans, []); + }); + + it('should record telemetry data when enabled (single call 
path)', async () => { + await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: null, + doRerank: mockRerank( + dummyDocuments, + dummyDocumentsIndices, + dummyRerankedDocuments, + { + tokens: 10, + }, + ), + }), + values: dummyDocuments, + query, + topK, + experimental_telemetry: { + isEnabled: true, + functionId: 'test-function-id', + metadata: { + test1: 'value1', + test2: false, + }, + tracer, + }, + }); + + expect(tracer.jsonSpans).toMatchSnapshot(); + }); + + it('should record telemetry data when enabled (multiple calls path)', async () => { + let callCount = 0; + await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: 1, + doRerank: async ({ values }) => { + switch (callCount++) { + case 0: + assert.deepStrictEqual(values, dummyDocuments.slice(0, 1)); + return { + rerankedIndices: dummyDocumentsIndices.slice(0, 1), + usage: { tokens: 10 }, + }; + case 1: + assert.deepStrictEqual(values, dummyDocuments.slice(1)); + return { + rerankedIndices: dummyDocumentsIndices.slice(1), + usage: { tokens: 10 }, + }; + default: + throw new Error('Unexpected call'); + } + }, + }), + values: dummyDocuments, + query, + topK, + experimental_telemetry: { + isEnabled: true, + functionId: 'test-function-id', + metadata: { + test1: 'value1', + test2: false, + }, + tracer, + }, + }); + + expect(tracer.jsonSpans).toMatchSnapshot(); + }); + + it('should not record telemetry inputs / outputs when disabled', async () => { + await rerank({ + model: new MockRerankingModelV1({ + maxDocumentsPerCall: null, + doRerank: mockRerank( + dummyDocuments, + dummyDocumentsIndices, + dummyRerankedDocuments, + { + tokens: 10, + }, + ), + }), + values: dummyDocuments, + query, + topK, + experimental_telemetry: { + isEnabled: true, + recordInputs: false, + recordOutputs: false, + tracer, + }, + }); + + expect(tracer.jsonSpans).toMatchSnapshot(); + }); +}); diff --git a/packages/ai/core/rerank/rerank.ts b/packages/ai/core/rerank/rerank.ts new file mode 100644 index 
000000000000..4141c2540e55 --- /dev/null +++ b/packages/ai/core/rerank/rerank.ts @@ -0,0 +1,310 @@ +import { RerankingModelV1DocumentIndex } from '@ai-sdk/provider'; +import { retryWithExponentialBackoff } from '../../util/retry-with-exponential-backoff'; +import { assembleOperationName } from '../telemetry/assemble-operation-name'; +import { getBaseTelemetryAttributes } from '../telemetry/get-base-telemetry-attributes'; +import { getTracer } from '../telemetry/get-tracer'; +import { recordSpan } from '../telemetry/record-span'; +import { selectTelemetryAttributes } from '../telemetry/select-telemetry-attributes'; +import { TelemetrySettings } from '../telemetry/telemetry-settings'; +import { RerankingModel } from '../types'; +import { splitArray } from '../util/split-array'; +import { RerankResult } from './rerank-result'; + +/** +Rerank documents using an reranking model. The type of the value is defined by the reranking model. + +@param model - The Reranking model to use. +@param values - The documents that should be reranking. +@param query - The query is a string that represents the query to rerank the documents against. +@param topK - Top k documents to rerank. +@param returnDocuments - Return the reranked documents in the response (In same order as indices). + +@param maxRetries - Maximum number of retries. Set to 0 to disable retries. Default: 2. +@param abortSignal - An optional abort signal that can be used to cancel the call. +@param headers - Additional HTTP headers to be sent with the request. Only applicable for HTTP-based providers. + +@returns A result object that contains the reranked documents, the reranked indices, and additional information. + */ +export async function rerank({ + model, + values, + query, + topK, + returnDocuments = false, + maxRetries, + abortSignal, + headers, + experimental_telemetry: telemetry, +}: { + /** +The reranking model to use. + */ + model: RerankingModel; + + /** +The documents that should be reranked. 
+ */ + values: Array; + + /** +The query is a string that represents the query to rerank the documents against. + */ + query: string; + + /** +Top k documents to rerank. + */ + topK: number; + + /** +Return the reranked documents in the response (In same order as indices). + +@default false + */ + returnDocuments?: boolean; + + /** +Maximum number of retries per reranking model call. Set to 0 to disable retries. + +@default 2 + */ + maxRetries?: number; + + /** +Abort signal. + */ + abortSignal?: AbortSignal; + + /** +Additional headers to include in the request. +Only applicable for HTTP-based providers. + */ + headers?: Record; + + /** + * Optional telemetry configuration (experimental). + */ + experimental_telemetry?: TelemetrySettings; +}): Promise> { + const baseTelemetryAttributes = getBaseTelemetryAttributes({ + model, + telemetry, + headers, + settings: { maxRetries }, + }); + + const tracer = getTracer(telemetry); + + return recordSpan({ + name: 'ai.rerank', + attributes: selectTelemetryAttributes({ + telemetry, + attributes: { + ...assembleOperationName({ operationId: 'ai.rerank', telemetry }), + ...baseTelemetryAttributes, + 'ai.values': { input: () => JSON.stringify(values) }, + }, + }), + tracer, + fn: async span => { + const retry = retryWithExponentialBackoff({ maxRetries }); + const maxDocumentsPerCall = model.maxDocumentsPerCall; + + if (maxDocumentsPerCall == null) { + const { rerankedIndices, usage, rerankedDocuments, rawResponse } = + await retry(() => + recordSpan({ + name: 'ai.rerank.doRerank', + attributes: selectTelemetryAttributes({ + telemetry, + attributes: { + ...assembleOperationName({ + operationId: 'ai.rerank.doRerank', + telemetry, + }), + ...baseTelemetryAttributes, + // specific settings that only make sense on the outer level: + 'ai.values': { input: () => [JSON.stringify(values)] }, + }, + }), + tracer, + fn: async doRerankSpan => { + const modelResponse = await model.doRerank({ + values, + query, + topK, + returnDocuments, + 
abortSignal, + headers, + }); + + const rerankedIndices = modelResponse.rerankedIndices; + const usage = modelResponse.usage ?? { tokens: NaN }; + const rerankedDocuments = modelResponse.rerankedDocuments ?? []; + + doRerankSpan.setAttributes( + selectTelemetryAttributes({ + telemetry, + attributes: { + 'ai.rerankedIndices': { + output: () => JSON.stringify(rerankedIndices), + }, + 'ai.usage.tokens': usage.tokens, + 'ai.rerankedDocuments': { + output: () => JSON.stringify(rerankedDocuments), + }, + }, + }), + ); + + return { + rerankedIndices, + usage, + rerankedDocuments, + rawResponse: modelResponse.rawResponse, + }; + }, + }), + ); + + span.setAttributes( + selectTelemetryAttributes({ + telemetry, + attributes: { + 'ai.rerankedIndices': { + output: () => JSON.stringify(rerankedIndices), + }, + 'ai.rerankedDocuments': { + output: () => JSON.stringify(rerankedDocuments), + }, + 'ai.usage.tokens': usage.tokens, + }, + }), + ); + + return new DefaultRerankResult({ + values, + rerankedIndices, + rerankedDocuments, + usage, + rawResponse, + }); + } + + // split the values into chunks that are small enough for the model: + const valueChunks = splitArray(values, maxDocumentsPerCall); + + const rerankedIndices: Array = []; + const rerankedDocuments: Array = []; + let tokens = 0; + + for (const chunk of valueChunks) { + const { + rerankedIndices: chunkIndices, + rerankedDocuments: chunkedRerankedDocuments, + usage, + } = await retry(() => + recordSpan({ + name: 'ai.rerank.doRerank', + attributes: selectTelemetryAttributes({ + telemetry, + attributes: { + ...assembleOperationName({ + operationId: 'ai.rerank.doRerank', + telemetry, + }), + ...baseTelemetryAttributes, + 'ai.values': { input: () => [JSON.stringify(chunk)] }, + }, + }), + tracer, + fn: async doRerankSpan => { + const modelResponse = await model.doRerank({ + values: chunk, + query, + topK, + returnDocuments, + abortSignal, + headers, + }); + + const chunkIndices = modelResponse.rerankedIndices; + const usage = 
modelResponse.usage ?? { tokens: NaN }; + const chunkedRerankedDocuments = + modelResponse.rerankedDocuments ?? []; + + doRerankSpan.setAttributes( + selectTelemetryAttributes({ + telemetry, + attributes: { + 'ai.rerankedIndices': { + output: () => JSON.stringify(chunkIndices), + }, + 'ai.usage.tokens': usage.tokens, + 'ai.rerankedDocuments': { + output: () => JSON.stringify(chunkedRerankedDocuments), + }, + }, + }), + ); + + return { + rerankedIndices: chunkIndices, + usage, + rerankedDocuments: chunkedRerankedDocuments, + }; + }, + }), + ); + + rerankedIndices.push(...chunkIndices); + rerankedDocuments.push(...chunkedRerankedDocuments); + tokens += usage.tokens; + } + + span.setAttributes( + selectTelemetryAttributes({ + telemetry, + attributes: { + 'ai.rerankedIndices': { + output: () => JSON.stringify(rerankedIndices), + }, + 'ai.rerankedDocuments': { + output: () => JSON.stringify(rerankedDocuments), + }, + 'ai.usage.tokens': tokens, + }, + }), + ); + + return new DefaultRerankResult({ + values, + rerankedIndices, + rerankedDocuments, + usage: { tokens }, + }); + }, + }); +} + +class DefaultRerankResult implements RerankResult { + readonly documents: RerankResult['documents']; + readonly rerankedIndices: RerankResult['rerankedIndices']; + readonly rerankedDocuments: RerankResult['documents']; + readonly usage: RerankResult['usage']; + readonly rawResponse: RerankResult['rawResponse']; + + constructor(options: { + values: RerankResult['documents']; + rerankedIndices: RerankResult['rerankedIndices']; + rerankedDocuments: RerankResult['documents']; + usage: RerankResult['usage']; + rawResponse?: RerankResult['rawResponse']; + }) { + this.documents = options.values; + this.rerankedIndices = options.rerankedIndices; + this.rerankedDocuments = options.rerankedDocuments; + this.usage = options.usage; + this.rawResponse = options.rawResponse; + } +} diff --git a/packages/ai/core/test/mock-reranking-model-v1.ts b/packages/ai/core/test/mock-reranking-model-v1.ts new 
file mode 100644 index 000000000000..ebeb72c444c0 --- /dev/null +++ b/packages/ai/core/test/mock-reranking-model-v1.ts @@ -0,0 +1,51 @@ +import { RerankingModelV1 } from '@ai-sdk/provider'; +import { RerankedDocumentIndex } from '../types'; +import { RerankingModelUsage } from '../types/usage'; + +export class MockRerankingModelV1 implements RerankingModelV1 { + readonly specificationVersion = 'v1'; + + readonly provider: RerankingModelV1['provider']; + readonly modelId: RerankingModelV1['modelId']; + readonly maxDocumentsPerCall: RerankingModelV1['maxDocumentsPerCall']; + readonly supportsParallelCalls: RerankingModelV1['supportsParallelCalls']; + readonly returnInput: RerankingModelV1['returnInput'] = false; + + doRerank: RerankingModelV1['doRerank']; + + constructor({ + provider = 'mock-provider', + modelId = 'mock-model-id', + maxDocumentsPerCall = 1, + supportsParallelCalls = false, + doRerank = notImplemented, + }: { + provider?: RerankingModelV1['provider']; + modelId?: RerankingModelV1['modelId']; + maxDocumentsPerCall?: RerankingModelV1['maxDocumentsPerCall'] | null; + supportsParallelCalls?: RerankingModelV1['supportsParallelCalls']; + doRerank?: RerankingModelV1['doRerank']; + } = {}) { + this.provider = provider; + this.modelId = modelId; + this.maxDocumentsPerCall = maxDocumentsPerCall ?? 
undefined; + this.supportsParallelCalls = supportsParallelCalls; + this.doRerank = doRerank; + } +} + +export function mockRerank( + expectedValues: Array, + rerankedIndices: Array, + rerankedDocuments?: Array, + usage?: RerankingModelUsage, +): RerankingModelV1['doRerank'] { + return async ({ values }) => { + assert.deepStrictEqual(expectedValues, values); + return { rerankedIndices, rerankedDocuments, usage }; + }; +} + +function notImplemented(): never { + throw new Error('Not implemented'); +} diff --git a/packages/ai/core/types/index.ts b/packages/ai/core/types/index.ts index 8bdb0ffb44c8..25ed2b1d3827 100644 --- a/packages/ai/core/types/index.ts +++ b/packages/ai/core/types/index.ts @@ -1,4 +1,5 @@ export type { Embedding, EmbeddingModel } from './embedding-model'; +export type { RerankingModel, RerankedDocumentIndex } from './reranking-model'; export type { ImageModel, ImageGenerationWarning as ImageModelCallWarning, diff --git a/packages/ai/core/types/reranking-model.ts b/packages/ai/core/types/reranking-model.ts new file mode 100644 index 000000000000..e392d4e67792 --- /dev/null +++ b/packages/ai/core/types/reranking-model.ts @@ -0,0 +1,14 @@ +import { + RerankingModelV1, + RerankingModelV1DocumentIndex, +} from '@ai-sdk/provider'; + +/** +Reranking model that is used by the AI SDK Core functions. +*/ +export type RerankingModel = RerankingModelV1; + +/** +RerankedDocumentIndex is a index of reranked documents. + */ +export type RerankedDocumentIndex = RerankingModelV1DocumentIndex; diff --git a/packages/ai/core/types/usage.ts b/packages/ai/core/types/usage.ts index 1e00c03b294b..50502662cdd8 100644 --- a/packages/ai/core/types/usage.ts +++ b/packages/ai/core/types/usage.ts @@ -42,6 +42,16 @@ export function calculateLanguageModelUsage({ }; } +/** +Represents the number of tokens used in a reranking model. + */ +export type RerankingModelUsage = { + /** +The number of tokens used for reranking. 
+ */ + tokens: number; +}; + export function addLanguageModelUsage( usage1: LanguageModelUsage, usage2: LanguageModelUsage, diff --git a/packages/amazon-bedrock/src/bedrock-provider.ts b/packages/amazon-bedrock/src/bedrock-provider.ts index a8e697bd50b6..57e64d5b1712 100644 --- a/packages/amazon-bedrock/src/bedrock-provider.ts +++ b/packages/amazon-bedrock/src/bedrock-provider.ts @@ -1,7 +1,9 @@ import { EmbeddingModelV1, LanguageModelV1, + NoSuchModelError, ProviderV1, + RerankingModelV1, } from '@ai-sdk/provider'; import { generateId, @@ -22,6 +24,11 @@ import { BedrockEmbeddingModelId, BedrockEmbeddingSettings, } from './bedrock-embedding-settings'; +import { BedrockRerankingModel } from './bedrock-reranking-model'; +import { + BedrockRerankingModelId, + BedrockRerankingSettings, +} from './bedrock-reranking-settings'; export interface AmazonBedrockProviderSettings { region?: string; @@ -55,6 +62,11 @@ export interface AmazonBedrockProvider extends ProviderV1 { modelId: BedrockEmbeddingModelId, settings?: BedrockEmbeddingSettings, ): EmbeddingModelV1; + + reranking( + modelId: BedrockRerankingModelId, + settings?: BedrockRerankingSettings, + ): RerankingModelV1; } /** @@ -123,10 +135,20 @@ export function createAmazonBedrock( client: createBedrockRuntimeClient(), }); + const createRerankingModel = ( + modelId: string, + settings: BedrockRerankingSettings = {}, + ) => + new BedrockRerankingModel(modelId, settings, { + client: createBedrockRuntimeClient(), + }); + provider.languageModel = createChatModel; provider.embedding = createEmbeddingModel; provider.textEmbedding = createEmbeddingModel; provider.textEmbeddingModel = createEmbeddingModel; + provider.reranking = createRerankingModel; + provider.rerankingModel = createRerankingModel; return provider as AmazonBedrockProvider; } diff --git a/packages/amazon-bedrock/src/bedrock-reranking-model.test.ts b/packages/amazon-bedrock/src/bedrock-reranking-model.test.ts new file mode 100644 index
000000000000..00a42ccb633b --- /dev/null +++ b/packages/amazon-bedrock/src/bedrock-reranking-model.test.ts @@ -0,0 +1,86 @@ +import { mockClient } from 'aws-sdk-client-mock'; +import { createAmazonBedrock } from './bedrock-provider'; +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from '@aws-sdk/client-bedrock-runtime'; + +const bedrockMock = mockClient(BedrockRuntimeClient); + +const dummyDocumentsIndices = [1, 0]; +const dummyDocuments = ['sunny day at the beach', 'rainy day in the city']; + +const provider = createAmazonBedrock({ + region: 'us-east-1', + accessKeyId: 'test-access-key', + secretAccessKey: 'test-secret-key', + sessionToken: 'test-token-key', +}); +const model = provider.reranking('amazon.rerank-v1'); + +describe('doRerank', () => { + beforeEach(() => { + bedrockMock.reset(); + }); + + it('should rerank documents', async () => { + const mockResponse = { + results: [ + { + index: 1, + score: 0.8, + }, + { + index: 0, + score: 0.7, + }, + ], + }; + + bedrockMock.on(InvokeModelCommand).resolves({ + //@ts-ignore + body: new TextEncoder().encode(JSON.stringify(mockResponse)), + }); + + const { rerankedIndices } = await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + }); + + expect(rerankedIndices).toStrictEqual(dummyDocumentsIndices); + }); + + it('should rerank documents and return documents', async () => { + const mockResponse = { + results: [ + { + index: 1, + score: 0.8, + }, + { + index: 0, + score: 0.7, + }, + ], + }; + + bedrockMock.on(InvokeModelCommand).resolves({ + //@ts-ignore + body: new TextEncoder().encode(JSON.stringify(mockResponse)), + }); + + const { rerankedIndices, rerankedDocuments } = await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + returnDocuments: true, + }); + + expect(rerankedDocuments).toStrictEqual( + dummyDocumentsIndices.map(index => dummyDocuments[index]), + ); + + expect(rerankedIndices).toStrictEqual(dummyDocumentsIndices); + }); +}); diff --git 
a/packages/amazon-bedrock/src/bedrock-reranking-model.ts b/packages/amazon-bedrock/src/bedrock-reranking-model.ts new file mode 100644 index 000000000000..3e55cb3ee36b --- /dev/null +++ b/packages/amazon-bedrock/src/bedrock-reranking-model.ts @@ -0,0 +1,72 @@ +import { RerankingModelV1 } from '@ai-sdk/provider'; +import { + BedrockRuntimeClient, + InvokeModelCommand, +} from '@aws-sdk/client-bedrock-runtime'; + +import { + BedrockRerankingModelId, + BedrockRerankingSettings, +} from './bedrock-reranking-settings'; + +type BedrockRerankingConfig = { + client: BedrockRuntimeClient; +}; + +export class BedrockRerankingModel implements RerankingModelV1 { + readonly specificationVersion = 'v1'; + readonly modelId: BedrockRerankingModelId; + readonly provider = 'amazon-bedrock'; + readonly maxDocumentsPerCall = undefined; + readonly supportsParallelCalls = true; + readonly returnInput = true; + private readonly config: BedrockRerankingConfig; + private readonly settings: BedrockRerankingSettings; + + constructor( + modelId: BedrockRerankingModelId, + settings: BedrockRerankingSettings, + config: BedrockRerankingConfig, + ) { + this.modelId = modelId; + this.settings = settings; + this.config = config; + } + + async doRerank({ + values, + query, + topK, + returnDocuments, + }: Parameters['doRerank']>[0]): Promise< + Awaited['doRerank']>> + > { + const payload = { + query: query, + documents: values, + top_n: topK, + }; + + const command = new InvokeModelCommand({ + contentType: 'application/json', + body: JSON.stringify(payload), + modelId: this.modelId, + }); + + const rawResponse = await this.config.client.send(command); + + const parsed: { + results: { + index: number; + relevance_score: number; + }[]; + } = JSON.parse(new TextDecoder().decode(rawResponse.body)); + + return { + rerankedIndices: parsed.results.map(result => result.index), + rerankedDocuments: returnDocuments + ? 
parsed.results.map(result => values[result.index]) + : undefined, + }; + } +} diff --git a/packages/amazon-bedrock/src/bedrock-reranking-settings.ts b/packages/amazon-bedrock/src/bedrock-reranking-settings.ts new file mode 100644 index 000000000000..9d6aca541131 --- /dev/null +++ b/packages/amazon-bedrock/src/bedrock-reranking-settings.ts @@ -0,0 +1,6 @@ +export type BedrockRerankingModelId = + | 'amazon.rerank-v1' + | 'cohere.rerank-v3-5:0' + | (string & {}); + +export interface BedrockRerankingSettings {} diff --git a/packages/anthropic/src/anthropic-provider.ts b/packages/anthropic/src/anthropic-provider.ts index 10c806b3ae01..af0c876eed77 100644 --- a/packages/anthropic/src/anthropic-provider.ts +++ b/packages/anthropic/src/anthropic-provider.ts @@ -130,6 +130,9 @@ export function createAnthropic( provider.textEmbeddingModel = (modelId: string) => { throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); }; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; provider.tools = anthropicTools; diff --git a/packages/azure/src/azure-openai-provider.ts b/packages/azure/src/azure-openai-provider.ts index 521ab3c8ec1e..bd00f2a130a0 100644 --- a/packages/azure/src/azure-openai-provider.ts +++ b/packages/azure/src/azure-openai-provider.ts @@ -9,6 +9,7 @@ import { import { EmbeddingModelV1, LanguageModelV1, + NoSuchModelError, ProviderV1, } from '@ai-sdk/provider'; import { FetchFunction, loadApiKey, loadSetting } from '@ai-sdk/provider-utils'; @@ -183,6 +184,9 @@ export function createAzure( provider.embedding = createEmbeddingModel; provider.textEmbedding = createEmbeddingModel; provider.textEmbeddingModel = createEmbeddingModel; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as AzureOpenAIProvider; } diff --git a/packages/cohere/src/cohere-provider.ts 
b/packages/cohere/src/cohere-provider.ts index 6a0673a2308d..0dbf5822d174 100644 --- a/packages/cohere/src/cohere-provider.ts +++ b/packages/cohere/src/cohere-provider.ts @@ -2,7 +2,9 @@ import { EmbeddingModelV1, LanguageModelV1, ProviderV1, + RerankingModelV1, } from '@ai-sdk/provider'; + import { FetchFunction, loadApiKey, @@ -15,6 +17,11 @@ import { CohereEmbeddingModelId, CohereEmbeddingSettings, } from './cohere-embedding-settings'; +import { + CohereRerankingModelId, + CohereRerankingSettings, +} from './cohere-reranking-settings'; +import { CohereRerankingModel } from './cohere-reranking-model'; export interface CohereProvider extends ProviderV1 { (modelId: CohereChatModelId, settings?: CohereChatSettings): LanguageModelV1; @@ -36,6 +43,16 @@ Creates a model for text generation. modelId: CohereEmbeddingModelId, settings?: CohereEmbeddingSettings, ): EmbeddingModelV1; + + reranking( + modelId: CohereRerankingModelId, + settings?: CohereRerankingSettings, + ): RerankingModelV1; + + rerankingModel( + modelId: CohereRerankingModelId, + settings?: CohereRerankingSettings, + ): RerankingModelV1; } export interface CohereProviderSettings { @@ -103,6 +120,17 @@ export function createCohere( fetch: options.fetch, }); + const createRerankingModel = ( + modelId: CohereRerankingModelId, + settings: CohereRerankingSettings = {}, + ) => + new CohereRerankingModel(modelId, settings, { + provider: 'cohere.reranking', + baseURL, + headers: getHeaders, + fetch: options.fetch, + }); + const provider = function ( modelId: CohereChatModelId, settings?: CohereChatSettings, @@ -119,6 +147,8 @@ export function createCohere( provider.languageModel = createChatModel; provider.embedding = createTextEmbeddingModel; provider.textEmbeddingModel = createTextEmbeddingModel; + provider.reranking = createRerankingModel; + provider.rerankingModel = createRerankingModel; return provider as CohereProvider; } diff --git a/packages/cohere/src/cohere-reranking-model.test.ts 
b/packages/cohere/src/cohere-reranking-model.test.ts new file mode 100644 index 000000000000..7af3ff69000a --- /dev/null +++ b/packages/cohere/src/cohere-reranking-model.test.ts @@ -0,0 +1,181 @@ +import { createCohere } from './cohere-provider'; +import { JsonTestServer } from '@ai-sdk/provider-utils/test'; +import { RerankingModelV1DocumentIndex } from '@ai-sdk/provider'; + +const dummyDocumentsIndices = [1, 0]; +const dummyDocuments = ['sunny day at the beach', 'rainy day in the city']; + +const provider = createCohere({ + baseURL: 'https://api.cohere.com/v1', + apiKey: 'test-api-key', +}); +const model = provider.reranking('rerank-english-v3.0'); + +describe('doRerank', () => { + const server = new JsonTestServer('https://api.cohere.com/v1/rerank'); + + server.setupTestEnvironment(); + + function prepareJsonResponse({ + rerankedIndices = dummyDocumentsIndices, + meta = { billed_units: { input_tokens: 8 } }, + }: { + rerankedIndices?: RerankingModelV1DocumentIndex[]; + meta?: { billed_units: { input_tokens: number } }; + } = {}) { + server.responseBodyJson = { + id: 'test-id', + results: rerankedIndices.map(index => ({ + index, + document: { text: dummyDocuments[index] }, + })), + meta, + }; + } + + it('should rerank documents', async () => { + prepareJsonResponse(); + + const { rerankedIndices } = await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + }); + + expect(rerankedIndices).toStrictEqual(dummyDocumentsIndices); + }); + + it('should rerank documents and return documents', async () => { + prepareJsonResponse(); + + const { rerankedIndices, rerankedDocuments } = await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + returnDocuments: true, + }); + + expect(rerankedDocuments).toStrictEqual( + dummyDocumentsIndices.map(index => dummyDocuments[index]), + ); + + expect(rerankedIndices).toStrictEqual(dummyDocumentsIndices); + }); + + it('should expose the raw response headers', async () => { + 
prepareJsonResponse(); + + server.responseHeaders = { + 'test-header': 'test-value', + }; + + const { rawResponse } = await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + }); + + expect(rawResponse?.headers).toStrictEqual({ + // default headers: + 'content-length': '184', + 'content-type': 'application/json', + + // custom header + 'test-header': 'test-value', + }); + }); + + it('should extract usage', async () => { + prepareJsonResponse({ + meta: { billed_units: { input_tokens: 20 } }, + }); + + const { usage } = await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + }); + + expect(usage).toStrictEqual({ tokens: 20 }); + }); + + it('should pass the model and the values', async () => { + prepareJsonResponse(); + + await model.doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'rerank-english-v3.0', + documents: dummyDocuments, + query: 'rainy day', + top_n: 2, + }); + }); + + it('should pass the input_type setting', async () => { + prepareJsonResponse(); + + await provider + .reranking('rerank-english-v3.0', { + max_chunks_per_document: 2, + }) + .doRerank({ values: dummyDocuments, query: 'rainy day', topK: 2 }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'rerank-english-v3.0', + documents: dummyDocuments, + query: 'rainy day', + top_n: 2, + max_chunks_per_doc: 2, + }); + + await provider.reranking('rerank-english-v3.0').doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + returnDocuments: true, + }); + + expect(await server.getRequestBodyJson()).toStrictEqual({ + model: 'rerank-english-v3.0', + documents: dummyDocuments, + query: 'rainy day', + top_n: 2, + return_documents: true, + }); + }); + + it('should pass headers', async () => { + prepareJsonResponse(); + + const provider = createCohere({ + baseURL: 'https://api.cohere.com/v1', + apiKey: 'test-api-key', + 
headers: { + 'Custom-Provider-Header': 'provider-header-value', + }, + }); + + await provider.reranking('rerank-english-v3.0').doRerank({ + values: dummyDocuments, + query: 'rainy day', + topK: 2, + headers: { + 'Custom-Request-Header': 'request-header-value', + }, + }); + + const requestHeaders = await server.getRequestHeaders(); + + expect(requestHeaders).toStrictEqual({ + authorization: 'Bearer test-api-key', + 'content-type': 'application/json', + 'custom-provider-header': 'provider-header-value', + 'custom-request-header': 'request-header-value', + }); + }); +}); diff --git a/packages/cohere/src/cohere-reranking-model.ts b/packages/cohere/src/cohere-reranking-model.ts new file mode 100644 index 000000000000..ebda416bba18 --- /dev/null +++ b/packages/cohere/src/cohere-reranking-model.ts @@ -0,0 +1,123 @@ +import { + RerankingModelV1, + TooManyDocumentsForRerankingError, +} from '@ai-sdk/provider'; + +import { + combineHeaders, + createJsonResponseHandler, + FetchFunction, + postJsonToApi, +} from '@ai-sdk/provider-utils'; + +import { z } from 'zod'; +import { cohereFailedResponseHandler } from './cohere-error'; +import { + CohereRerankingModelId, + CohereRerankingSettings, +} from './cohere-reranking-settings'; + +type CohereRerankingConfig = { + provider: string; + baseURL: string; + headers: () => Record; + fetch?: FetchFunction; +}; + +export class CohereRerankingModel implements RerankingModelV1 { + readonly specificationVersion = 'v1'; + readonly modelId: CohereRerankingModelId; + + readonly maxDocumentsPerCall = 10000; + readonly supportsParallelCalls = true; + readonly returnInput = false; + + private readonly config: CohereRerankingConfig; + private readonly settings: CohereRerankingSettings; + + constructor( + modelId: CohereRerankingModelId, + settings: CohereRerankingSettings, + config: CohereRerankingConfig, + ) { + this.modelId = modelId; + this.settings = settings; + this.config = config; + } + + get provider(): string { + return 
this.config.provider; + } + + // current implementation is based on v1 of the API: https://docs.cohere.com/v1/reference/rerank + async doRerank({ + values, + headers, + query, + topK, + returnDocuments, + abortSignal, + }: Parameters['doRerank']>[0]): Promise< + Awaited['doRerank']>> + > { + const totalMaxChunks = this.settings.max_chunks_per_document ?? 1; + // The total max chunks (length of documents * max_chunks_per_doc) must be less than 10000. + if (values.length * totalMaxChunks > this.maxDocumentsPerCall) { + throw new TooManyDocumentsForRerankingError({ + provider: this.provider, + modelId: this.modelId, + maxDocumentsPerCall: this.maxDocumentsPerCall, + documents: values, + }); + } + + const { responseHeaders, value: response } = await postJsonToApi({ + url: `${this.config.baseURL}/rerank`, + headers: combineHeaders(this.config.headers(), headers), + body: { + model: this.modelId, + documents: values, + query: query, + top_n: topK ?? values.length, + return_documents: returnDocuments, + max_chunks_per_doc: this.settings.max_chunks_per_document, + }, + failedResponseHandler: cohereFailedResponseHandler, + successfulResponseHandler: createJsonResponseHandler( + cohereRerankingResponseSchema, + ), + abortSignal, + fetch: this.config.fetch, + }); + + return { + rerankedIndices: response.results.map(result => result.index), + rerankedDocuments: returnDocuments + ? response.results.map(result => result.document?.text ?? 
'') + : undefined, + usage: { tokens: response.meta.billed_units.input_tokens }, + rawResponse: { headers: responseHeaders }, + }; + } +} + +// minimal version of the schema, focussed on what is needed for the implementation +// this approach limits breakages when the API changes and increases efficiency +const cohereRerankingResponseSchema = z.object({ + results: z.array( + z.object({ + index: z.number(), + document: z + .object({ + text: z.string(), + }) + .catchall(z.any()) + .optional(), + }), + ), + meta: z.object({ + billed_units: z.object({ + input_tokens: z.number(), + }), + }), +}); diff --git a/packages/cohere/src/cohere-reranking-settings.ts b/packages/cohere/src/cohere-reranking-settings.ts new file mode 100644 index 000000000000..65f93e69eb72 --- /dev/null +++ b/packages/cohere/src/cohere-reranking-settings.ts @@ -0,0 +1,22 @@ +export type CohereRerankingModelId = + | 'rerank-english-v3.0' + | 'rerank-multilingual-v3.0' + | 'rerank-english-v2.0' + | 'rerank-multilingual-v2.0' + | (string & {}); + +export interface CohereRerankingSettings { + /** + * The maximum number of chunks to produce internally from a document + * + */ + max_chunks_per_document?: number; + + /** + * If a JSON object is provided, you can specify which keys you would like to have considered for reranking. + * The model will rerank based on order of the fields passed in (i.e. rank_fields=[‘title’,‘author’,‘text’] will rerank using the values in title, author, text sequentially. + * If the length of title, author, and text exceeds the context length of the model, the chunking will not re-consider earlier fields). + * If not provided, the model will use the default text field for ranking. 
+ */ + rank_fields?: string[]; +} diff --git a/packages/deepseek/src/deepseek-provider.ts b/packages/deepseek/src/deepseek-provider.ts index 76a92ef64ee7..bac5cc4ae8c8 100644 --- a/packages/deepseek/src/deepseek-provider.ts +++ b/packages/deepseek/src/deepseek-provider.ts @@ -98,6 +98,9 @@ export function createDeepSeek( provider.textEmbeddingModel = (modelId: string) => { throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); }; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as DeepSeekProvider; } diff --git a/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider.ts b/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider.ts index 0db7561d44de..aa5b396b4704 100644 --- a/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider.ts +++ b/packages/google-vertex/src/anthropic/google-vertex-anthropic-provider.ts @@ -134,6 +134,9 @@ export function createVertexAnthropic( provider.textEmbeddingModel = (modelId: string) => { throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); }; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; provider.tools = anthropicTools; diff --git a/packages/google-vertex/src/google-vertex-provider.ts b/packages/google-vertex/src/google-vertex-provider.ts index 9b7eac4b00a9..6dee6a8ab501 100644 --- a/packages/google-vertex/src/google-vertex-provider.ts +++ b/packages/google-vertex/src/google-vertex-provider.ts @@ -1,4 +1,9 @@ -import { LanguageModelV1, ProviderV1, ImageModelV1 } from '@ai-sdk/provider'; +import { + LanguageModelV1, + ProviderV1, + ImageModelV1, + NoSuchModelError, +} from '@ai-sdk/provider'; import { FetchFunction, generateId, @@ -155,6 +160,9 @@ export function createVertex( provider.languageModel = createChatModel; provider.textEmbeddingModel = createEmbeddingModel; + 
provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; provider.image = createImageModel; return provider as GoogleVertexProvider; diff --git a/packages/google/src/google-provider.ts b/packages/google/src/google-provider.ts index 6925bf811753..fbfe5aca92aa 100644 --- a/packages/google/src/google-provider.ts +++ b/packages/google/src/google-provider.ts @@ -17,6 +17,7 @@ import { import { EmbeddingModelV1, LanguageModelV1, + NoSuchModelError, ProviderV1, } from '@ai-sdk/provider'; @@ -157,6 +158,9 @@ export function createGoogleGenerativeAI( provider.embedding = createEmbeddingModel; provider.textEmbedding = createEmbeddingModel; provider.textEmbeddingModel = createEmbeddingModel; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as GoogleGenerativeAIProvider; } diff --git a/packages/groq/src/groq-provider.ts b/packages/groq/src/groq-provider.ts index 6b62be7733fe..1dd16d48be72 100644 --- a/packages/groq/src/groq-provider.ts +++ b/packages/groq/src/groq-provider.ts @@ -101,6 +101,9 @@ export function createGroq(options: GroqProviderSettings = {}): GroqProvider { provider.textEmbeddingModel = (modelId: string) => { throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); }; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as GroqProvider; } diff --git a/packages/mistral/src/mistral-provider.ts b/packages/mistral/src/mistral-provider.ts index 745d3d701511..a9b7a2d1d2ea 100644 --- a/packages/mistral/src/mistral-provider.ts +++ b/packages/mistral/src/mistral-provider.ts @@ -1,6 +1,7 @@ import { EmbeddingModelV1, LanguageModelV1, + NoSuchModelError, ProviderV1, } from '@ai-sdk/provider'; import { @@ -146,6 +147,9 @@ export function createMistral( provider.embedding = createEmbeddingModel; 
provider.textEmbedding = createEmbeddingModel; provider.textEmbeddingModel = createEmbeddingModel; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as MistralProvider; } diff --git a/packages/openai-compatible/src/openai-compatible-provider.ts b/packages/openai-compatible/src/openai-compatible-provider.ts index 9b04b23c5c78..9a1335f35705 100644 --- a/packages/openai-compatible/src/openai-compatible-provider.ts +++ b/packages/openai-compatible/src/openai-compatible-provider.ts @@ -1,6 +1,7 @@ import { EmbeddingModelV1, LanguageModelV1, + NoSuchModelError, ProviderV1, } from '@ai-sdk/provider'; import { @@ -163,6 +164,9 @@ export function createOpenAICompatible< provider.chatModel = createChatModel; provider.completionModel = createCompletionModel; provider.textEmbeddingModel = createEmbeddingModel; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as OpenAICompatibleProvider< CHAT_MODEL_IDS, diff --git a/packages/openai/src/openai-provider.ts b/packages/openai/src/openai-provider.ts index 18ca8919db4f..f83e6e031e34 100644 --- a/packages/openai/src/openai-provider.ts +++ b/packages/openai/src/openai-provider.ts @@ -2,6 +2,7 @@ import { EmbeddingModelV1, ImageModelV1, LanguageModelV1, + NoSuchModelError, ProviderV1, } from '@ai-sdk/provider'; import { @@ -237,6 +238,9 @@ export function createOpenAI( provider.textEmbedding = createEmbeddingModel; provider.textEmbeddingModel = createEmbeddingModel; provider.image = createImageModel; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as OpenAIProvider; } diff --git a/packages/provider/src/errors/index.ts b/packages/provider/src/errors/index.ts index 465022603e5a..836469413796 100644 --- a/packages/provider/src/errors/index.ts +++ 
b/packages/provider/src/errors/index.ts @@ -11,5 +11,6 @@ export { LoadSettingError } from './load-setting-error'; export { NoContentGeneratedError } from './no-content-generated-error'; export { NoSuchModelError } from './no-such-model-error'; export { TooManyEmbeddingValuesForCallError } from './too-many-embedding-values-for-call-error'; +export { TooManyDocumentsForRerankingError } from './too-many-documents-for-reranking-error'; export { TypeValidationError } from './type-validation-error'; export { UnsupportedFunctionalityError } from './unsupported-functionality-error'; diff --git a/packages/provider/src/errors/no-such-model-error.ts b/packages/provider/src/errors/no-such-model-error.ts index 9946fa445c32..d5cdcd90f425 100644 --- a/packages/provider/src/errors/no-such-model-error.ts +++ b/packages/provider/src/errors/no-such-model-error.ts @@ -8,7 +8,7 @@ export class NoSuchModelError extends AISDKError { private readonly [symbol] = true; // used in isInstance readonly modelId: string; - readonly modelType: 'languageModel' | 'textEmbeddingModel'; + readonly modelType: 'languageModel' | 'textEmbeddingModel' | 'rerankingModel'; constructor({ errorName = name, @@ -18,7 +18,7 @@ export class NoSuchModelError extends AISDKError { }: { errorName?: string; modelId: string; - modelType: 'languageModel' | 'textEmbeddingModel'; + modelType: 'languageModel' | 'textEmbeddingModel' | 'rerankingModel'; message?: string; }) { super({ name: errorName, message }); diff --git a/packages/provider/src/errors/too-many-documents-for-reranking-error.ts b/packages/provider/src/errors/too-many-documents-for-reranking-error.ts new file mode 100644 index 000000000000..7bd8ed4a5d00 --- /dev/null +++ b/packages/provider/src/errors/too-many-documents-for-reranking-error.ts @@ -0,0 +1,40 @@ +import { AISDKError } from './ai-sdk-error'; + +const name = 'AI_TooManyDocumentsForRerankingError'; +const marker = `vercel.ai.error.${name}`; +const symbol = Symbol.for(marker); + +export class 
TooManyDocumentsForRerankingError extends AISDKError { + private readonly [symbol] = true; // used in isInstance + + readonly provider: string; + readonly modelId: string; + readonly maxDocumentsPerCall: number; + readonly documents: Array; + + constructor(options: { + provider: string; + modelId: string; + maxDocumentsPerCall: number; + documents: Array; + }) { + super({ + name, + message: + `Too many documents for a single reranking call. ` + + `The ${options.provider} model "${options.modelId}" can only rerank up to ` + + `${options.maxDocumentsPerCall} documents per call, but ${options.documents.length} documents were provided.`, + }); + + this.provider = options.provider; + this.modelId = options.modelId; + this.maxDocumentsPerCall = options.maxDocumentsPerCall; + this.documents = options.documents; + } + + static isInstance( + error: unknown, + ): error is TooManyDocumentsForRerankingError { + return AISDKError.hasMarker(error, marker); + } +} diff --git a/packages/provider/src/index.ts b/packages/provider/src/index.ts index 1e258cf9baa8..bf3a0c781921 100644 --- a/packages/provider/src/index.ts +++ b/packages/provider/src/index.ts @@ -1,4 +1,5 @@ export * from './embedding-model/index'; +export * from './reranking-model/index'; export * from './errors/index'; export * from './image-model/index'; export * from './json-value/index'; diff --git a/packages/provider/src/provider/v1/provider-v1.ts b/packages/provider/src/provider/v1/provider-v1.ts index e197f53ff388..7be4860e7cd8 100644 --- a/packages/provider/src/provider/v1/provider-v1.ts +++ b/packages/provider/src/provider/v1/provider-v1.ts @@ -1,8 +1,9 @@ import { EmbeddingModelV1 } from '../../embedding-model/v1/embedding-model-v1'; import { LanguageModelV1 } from '../../language-model/v1/language-model-v1'; +import { RerankingModelV1 } from '../../reranking-model/v1/reranking-model-v1'; /** - * Provider for language and text embedding models. 
+ * Provider for language, text embedding models and reranking models. */ export interface ProviderV1 { /** @@ -28,4 +29,16 @@ The model id is then passed to the provider function to get the model. @throws {NoSuchModelError} If no such model exists. */ textEmbeddingModel(modelId: string): EmbeddingModelV1; + + /** +Returns the reranking model with the given id. +The model id is then passed to the provider function to get the model. + +@param {string} modelId - The id of the model to return. + +@returns {RerankingModel} The reranking model associated with the id + +@throws {NoSuchModelError} If no such model exists. + */ + rerankingModel(modelId: string): RerankingModelV1; } diff --git a/packages/provider/src/reranking-model/index.ts b/packages/provider/src/reranking-model/index.ts new file mode 100644 index 000000000000..e69fb44d6835 --- /dev/null +++ b/packages/provider/src/reranking-model/index.ts @@ -0,0 +1 @@ +export * from './v1/index'; diff --git a/packages/provider/src/reranking-model/v1/index.ts b/packages/provider/src/reranking-model/v1/index.ts new file mode 100644 index 000000000000..20a40470e212 --- /dev/null +++ b/packages/provider/src/reranking-model/v1/index.ts @@ -0,0 +1,2 @@ +export * from './reranking-model-v1'; +export * from './reranking-model-v1-reranking'; diff --git a/packages/provider/src/reranking-model/v1/reranking-model-v1-reranking.ts b/packages/provider/src/reranking-model/v1/reranking-model-v1-reranking.ts new file mode 100644 index 000000000000..20088a2afd96 --- /dev/null +++ b/packages/provider/src/reranking-model/v1/reranking-model-v1-reranking.ts @@ -0,0 +1,5 @@ +/** +A RerankIndex is a index of reranking documents. +It is e.g. used to represent a index of original documents that are reranked. 
+ */ +export type RerankingModelV1DocumentIndex = number; diff --git a/packages/provider/src/reranking-model/v1/reranking-model-v1.ts b/packages/provider/src/reranking-model/v1/reranking-model-v1.ts new file mode 100644 index 000000000000..5de7d02ddba3 --- /dev/null +++ b/packages/provider/src/reranking-model/v1/reranking-model-v1.ts @@ -0,0 +1,107 @@ +import { RerankingModelV1DocumentIndex } from './reranking-model-v1-reranking'; + +/** +Experimental: Specification for a reranking model that implements the reranking model +interface version 1. + +VALUE is the type of the values that the model can rerank. + */ +export type RerankingModelV1 = { + /** +The reranking model must specify which reranking model interface +version it implements. This will allow us to evolve the reranking +model interface and retain backwards compatibility. The different +implementation versions can be handled as a discriminated union +on our side. + */ + readonly specificationVersion: 'v1'; + + /** +Name of the provider for logging purposes. + */ + readonly provider: string; + + /** +Provider-specific model ID for logging purposes. + */ + readonly modelId: string; + + /** +Limit of how many documents can be reranking in a single API call. + */ + readonly maxDocumentsPerCall: number | undefined; + + /** +True if the model can handle multiple reranking calls in parallel. + */ + readonly supportsParallelCalls: boolean; + + /** +True if the model can return the input values in the response. + */ + readonly returnInput: boolean; + + /** +Reranking a list of documents using the query +Naming: "do" prefix to prevent accidental direct usage of the method +by the user. + */ + doRerank(options: { + /** +List of documents to rerank. + */ + values: Array; + + /** +The query is a string that represents the query to rerank the documents against. + */ + query: string; + + /** +The top-k documents after reranking. 
+ */ + topK: number; + + /** +Return the reranked documents in the response (In same order as indices). + */ + returnDocuments?: boolean; + + /** +Abort signal for cancelling the operation. + */ + abortSignal?: AbortSignal; + + /** + Additional HTTP headers to be sent with the request. + Only applicable for HTTP-based providers. + */ + headers?: Record; + }): PromiseLike<{ + /** +Reranked document indices. + */ + rerankedIndices: Array; + + /** + * This is optional and only send if the returnDocuments flag is set to true. +Reranked documents will be in same order as indices. + */ + rerankedDocuments?: Array; + + /** +Token usage. We only have input tokens for reranking. + */ + usage?: { tokens: number }; + + /** +Optional raw response information for debugging purposes. + */ + rawResponse?: { + /** +Response headers. + */ + headers?: Record; + }; + }>; +}; diff --git a/packages/xai/src/xai-provider.ts b/packages/xai/src/xai-provider.ts index 5f82771599af..b494860c0ac4 100644 --- a/packages/xai/src/xai-provider.ts +++ b/packages/xai/src/xai-provider.ts @@ -113,6 +113,9 @@ export function createXai(options: XaiProviderSettings = {}): XaiProvider { provider.textEmbeddingModel = (modelId: string) => { throw new NoSuchModelError({ modelId, modelType: 'textEmbeddingModel' }); }; + provider.rerankingModel = (modelId: string) => { + throw new NoSuchModelError({ modelId, modelType: 'rerankingModel' }); + }; return provider as XaiProvider; }