diff --git a/.changeset/serious-jars-train.md b/.changeset/serious-jars-train.md new file mode 100644 index 000000000000..5fde187d3fb0 --- /dev/null +++ b/.changeset/serious-jars-train.md @@ -0,0 +1,6 @@ +--- +'@ai-sdk/google-vertex': patch +'@ai-sdk/google': patch +--- + +fix(vertex): allow 'vertex' as a key for providerOptions diff --git a/content/docs/08-migration-guides/24-migration-guide-6-0.mdx b/content/docs/08-migration-guides/24-migration-guide-6-0.mdx index 235cd4761ae6..b830e8b3b915 100644 --- a/content/docs/08-migration-guides/24-migration-guide-6-0.mdx +++ b/content/docs/08-migration-guides/24-migration-guide-6-0.mdx @@ -478,6 +478,52 @@ const result = await generateObject({ }); ``` +### Google Vertex + +#### `providerMetadata` and `providerOptions` Key + +The `@ai-sdk/google-vertex` provider now uses `vertex` as the key for `providerMetadata` and `providerOptions` instead of `google`. The `google` key is still supported for `providerOptions` input, but resulting `providerMetadata` output now uses `vertex`. + +```tsx filename="AI SDK 5" +import { vertex } from '@ai-sdk/google-vertex'; +import { generateText } from 'ai'; + +const result = await generateText({ + model: vertex('gemini-2.5-flash'), + providerOptions: { + google: { + safetySettings: [ + /* ... */ + ], + }, // Used 'google' key + }, + prompt: 'Hello', +}); + +// Accessed metadata via 'google' key +console.log(result.providerMetadata?.google?.safetyRatings); +``` + +```tsx filename="AI SDK 6" +import { vertex } from '@ai-sdk/google-vertex'; +import { generateText } from 'ai'; + +const result = await generateText({ + model: vertex('gemini-2.5-flash'), + providerOptions: { + vertex: { + safetySettings: [ + /* ... */ + ], + }, // Now uses 'vertex' key + }, + prompt: 'Hello', +}); + +// Access metadata via 'vertex' key +console.log(result.providerMetadata?.vertex?.safetyRatings); +``` + ## `ai/test` ### Mock Classes diff --git a/examples/ai-core/src/generate-text/google-vertex-safety.ts b/examples/ai-core/src/generate-text/google-vertex-safety.ts index 6e07a01b8225..7f7d56f2645a 100644 --- a/examples/ai-core/src/generate-text/google-vertex-safety.ts +++ b/examples/ai-core/src/generate-text/google-vertex-safety.ts @@ -4,9 +4,9 @@ import 'dotenv/config'; async function main() { const result = await generateText({ - model: vertex('gemini-1.5-pro'), + model: vertex('gemini-2.5-flash'), providerOptions: { - google: { + vertex: { safetySettings: [ { category: 'HARM_CATEGORY_UNSPECIFIED', @@ -22,6 +22,8 @@ async function main() { console.log(); console.log('Token usage:', result.usage); console.log('Finish reason:', result.finishReason); + console.log(); + console.log('Request body:', result.request?.body); } main().catch(console.error); diff --git a/packages/google-vertex/src/google-vertex-embedding-model.test.ts b/packages/google-vertex/src/google-vertex-embedding-model.test.ts index d3941a1bdc4b..387a0b2ae24f 100644 --- a/packages/google-vertex/src/google-vertex-embedding-model.test.ts +++ b/packages/google-vertex/src/google-vertex-embedding-model.test.ts @@ -138,6 +138,27 @@ describe('GoogleVertexEmbeddingModel', () => { }); }); + it('should accept vertex as provider options key', async () => { + prepareJsonResponse(); + + await model.doEmbed({ + values: testValues, + providerOptions: { vertex: mockProviderOptions }, + }); + + expect(await server.calls[0].requestBodyJson).toStrictEqual({ + instances: testValues.map(value => ({ + content: value, + task_type: mockProviderOptions.taskType, + title: mockProviderOptions.title, + })), + parameters: { + outputDimensionality: mockProviderOptions.outputDimensionality, + autoTruncate: mockProviderOptions.autoTruncate, + }, + }); + }); + it('should pass the taskType setting in instances', async () => { prepareJsonResponse(); diff --git a/packages/google-vertex/src/google-vertex-embedding-model.ts b/packages/google-vertex/src/google-vertex-embedding-model.ts index a81a1478fe55..0f81f87aca4f 100644 --- a/packages/google-vertex/src/google-vertex-embedding-model.ts +++ b/packages/google-vertex/src/google-vertex-embedding-model.ts @@ -45,13 +45,21 @@ export class GoogleVertexEmbeddingModel implements EmbeddingModelV3 { }: Parameters[0]): Promise< Awaited> > { - // Parse provider options - const googleOptions = - (await parseProviderOptions({ + let googleOptions = await parseProviderOptions({ + provider: 'vertex', + providerOptions, + schema: googleVertexEmbeddingProviderOptions, + }); + + if (googleOptions == null) { + googleOptions = await parseProviderOptions({ provider: 'google', providerOptions, schema: googleVertexEmbeddingProviderOptions, - })) ?? {}; + }); + } + + googleOptions = googleOptions ?? {}; if (values.length > this.maxEmbeddingsPerCall) { throw new TooManyEmbeddingValuesForCallError({ diff --git a/packages/google/src/google-generative-ai-language-model.test.ts b/packages/google/src/google-generative-ai-language-model.test.ts index 350a6dc5466d..98e6e722e11a 100644 --- a/packages/google/src/google-generative-ai-language-model.test.ts +++ b/packages/google/src/google-generative-ai-language-model.test.ts @@ -2118,6 +2118,194 @@ describe('doGenerate', () => { }, }); }); + + describe('providerMetadata key based on provider string', () => { + it('should use "vertex" as providerMetadata key when provider includes "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'json-value', + body: { + candidates: [ + { + content: { parts: [{ text: 'Hello!' }], role: 'model' }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + groundingMetadata: { + webSearchQueries: ['test query'], + }, + }, + ], + promptFeedback: { safetyRatings: SAFETY_RATINGS }, + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + }, + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.vertex.chat', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + }); + + const { providerMetadata } = await model.doGenerate({ + prompt: TEST_PROMPT, + }); + + expect(providerMetadata).toHaveProperty('vertex'); + expect(providerMetadata).not.toHaveProperty('google'); + expect(providerMetadata?.vertex).toMatchObject({ + promptFeedback: expect.any(Object), + groundingMetadata: expect.any(Object), + safetyRatings: expect.any(Array), + }); + }); + + it('should use "google" as providerMetadata key when provider does not include "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'json-value', + body: { + candidates: [ + { + content: { parts: [{ text: 'Hello!' }], role: 'model' }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + groundingMetadata: { + webSearchQueries: ['test query'], + }, + }, + ], + promptFeedback: { safetyRatings: SAFETY_RATINGS }, + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + }, + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.generative-ai', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + //should default to 'google' + }); + + const { providerMetadata } = await model.doGenerate({ + prompt: TEST_PROMPT, + }); + + expect(providerMetadata).toHaveProperty('google'); + expect(providerMetadata).not.toHaveProperty('vertex'); + expect(providerMetadata?.google).toMatchObject({ + promptFeedback: expect.any(Object), + groundingMetadata: expect.any(Object), + safetyRatings: expect.any(Array), + }); + }); + + it('should use "vertex" as providerMetadata key in thoughtSignature content when provider includes "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'json-value', + body: { + candidates: [ + { + content: { + parts: [ + { + text: 'thinking...', + thought: true, + thoughtSignature: 'sig123', + }, + { text: 'Final answer' }, + ], + role: 'model', + }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + }, + ], + promptFeedback: { safetyRatings: SAFETY_RATINGS }, + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + }, + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.vertex.chat', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + }); + + const { content } = await model.doGenerate({ + prompt: TEST_PROMPT, + }); + + const reasoningPart = content.find(part => part.type === 'reasoning'); + expect(reasoningPart?.providerMetadata).toHaveProperty('vertex'); + expect(reasoningPart?.providerMetadata).not.toHaveProperty('google'); + expect(reasoningPart?.providerMetadata?.vertex).toMatchObject({ + thoughtSignature: 'sig123', + }); + }); + + it('should use "vertex" as providerMetadata key in tool call content when provider includes "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'json-value', + body: { + candidates: [ + { + content: { + parts: [ + { + functionCall: { + name: 'test-tool', + args: { value: 'test' }, + }, + thoughtSignature: 'tool_sig', + }, + ], + role: 'model', + }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + }, + ], + promptFeedback: { safetyRatings: SAFETY_RATINGS }, + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + }, + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.vertex.chat', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + }); + + const { content } = await model.doGenerate({ + prompt: TEST_PROMPT, + }); + + const toolCallPart = content.find(part => part.type === 'tool-call'); + expect(toolCallPart?.providerMetadata).toHaveProperty('vertex'); + expect(toolCallPart?.providerMetadata).not.toHaveProperty('google'); + expect(toolCallPart?.providerMetadata?.vertex).toMatchObject({ + thoughtSignature: 'tool_sig', + }); + }); + }); }); describe('doStream', () => { @@ -3562,6 +3750,197 @@ describe('doStream', () => { expect(chunks.filter(chunk => chunk.type === 'raw')).toHaveLength(0); }); }); + + describe('providerMetadata key based on provider string', () => { + it('should use "vertex" as providerMetadata key in finish event when provider includes "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'stream-chunks', + chunks: [ + `data: ${JSON.stringify({ + candidates: [ + { + content: { parts: [{ text: 'Hello' }], role: 'model' }, + }, + ], + })}\n\n`, + `data: ${JSON.stringify({ + candidates: [ + { + content: { parts: [{ text: ' World!' }], role: 'model' }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + groundingMetadata: { + webSearchQueries: ['test query'], + }, + }, + ], + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + })}\n\n`, + ], + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.vertex.chat', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + }); + + const { stream } = await model.doStream({ + prompt: TEST_PROMPT, + }); + + const events = await convertReadableStreamToArray(stream); + const finishEvent = events.find(event => event.type === 'finish'); + + expect(finishEvent?.type === 'finish').toBe(true); + if (finishEvent?.type === 'finish') { + expect(finishEvent.providerMetadata).toHaveProperty('vertex'); + expect(finishEvent.providerMetadata?.vertex).toMatchObject({ + promptFeedback: null, + groundingMetadata: expect.any(Object), + safetyRatings: expect.any(Array), + }); + } + }); + + it('should use "google" as providerMetadata key in finish event when provider does not include "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'stream-chunks', + chunks: [ + `data: ${JSON.stringify({ + candidates: [ + { + content: { parts: [{ text: 'Hello' }], role: 'model' }, + }, + ], + })}\n\n`, + `data: ${JSON.stringify({ + candidates: [ + { + content: { parts: [{ text: ' World!' }], role: 'model' }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + groundingMetadata: { + webSearchQueries: ['test query'], + }, + }, + ], + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + })}\n\n`, + ], + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.generative-ai', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + }); + + const { stream } = await model.doStream({ + prompt: TEST_PROMPT, + }); + + const events = await convertReadableStreamToArray(stream); + const finishEvent = events.find(event => event.type === 'finish'); + + expect(finishEvent?.type === 'finish').toBe(true); + if (finishEvent?.type === 'finish') { + expect(finishEvent.providerMetadata).toHaveProperty('google'); + expect(finishEvent.providerMetadata).not.toHaveProperty('vertex'); + expect(finishEvent.providerMetadata?.google).toMatchObject({ + promptFeedback: null, + groundingMetadata: expect.any(Object), + safetyRatings: expect.any(Array), + }); + } + }); + + it('should use "vertex" as providerMetadata key in streaming reasoning events when provider includes "vertex"', async () => { + server.urls[TEST_URL_GEMINI_PRO].response = { + type: 'stream-chunks', + chunks: [ + `data: ${JSON.stringify({ + candidates: [ + { + content: { + parts: [ + { + text: 'thinking...', + thought: true, + thoughtSignature: 'stream_sig', + }, + ], + role: 'model', + }, + }, + ], + })}\n\n`, + `data: ${JSON.stringify({ + candidates: [ + { + content: { parts: [{ text: 'Final answer' }], role: 'model' }, + finishReason: 'STOP', + safetyRatings: SAFETY_RATINGS, + }, + ], + usageMetadata: { + promptTokenCount: 1, + candidatesTokenCount: 2, + totalTokenCount: 3, + }, + })}\n\n`, + ], + }; + + const model = new GoogleGenerativeAILanguageModel('gemini-pro', { + provider: 'google.vertex.chat', + baseURL: 'https://generativelanguage.googleapis.com/v1beta', + headers: { 'x-goog-api-key': 'test-api-key' }, + generateId: () => 'test-id', + }); + + const { stream } = await model.doStream({ + prompt: TEST_PROMPT, + }); + + const events = await convertReadableStreamToArray(stream); + + const reasoningStartEvent = events.find( + event => event.type === 'reasoning-start', + ); + expect(reasoningStartEvent?.type === 'reasoning-start').toBe(true); + if (reasoningStartEvent?.type === 'reasoning-start') { + expect(reasoningStartEvent.providerMetadata).toHaveProperty('vertex'); + expect(reasoningStartEvent.providerMetadata).not.toHaveProperty( + 'google', + ); + expect(reasoningStartEvent.providerMetadata?.vertex).toMatchObject({ + thoughtSignature: 'stream_sig', + }); + } + + const reasoningDeltaEvent = events.find( + event => event.type === 'reasoning-delta', + ); + expect(reasoningDeltaEvent?.type === 'reasoning-delta').toBe(true); + if (reasoningDeltaEvent?.type === 'reasoning-delta') { + expect(reasoningDeltaEvent.providerMetadata).toHaveProperty('vertex'); + expect(reasoningDeltaEvent.providerMetadata).not.toHaveProperty( + 'google', + ); + } + }); + }); }); describe('GEMMA Model System Instruction Fix', () => { diff --git a/packages/google/src/google-generative-ai-language-model.ts b/packages/google/src/google-generative-ai-language-model.ts index 4ba3cb3747ae..f7ee8a592d57 100644 --- a/packages/google/src/google-generative-ai-language-model.ts +++ b/packages/google/src/google-generative-ai-language-model.ts @@ -94,12 +94,23 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { }: Parameters[0]) { const warnings: SharedV3Warning[] = []; - const googleOptions = await parseProviderOptions({ - provider: 'google', + const providerOptionsName = this.config.provider.includes('vertex') + ? 'vertex' + : 'google'; + let googleOptions = await parseProviderOptions({ + provider: providerOptionsName, providerOptions, schema: googleGenerativeAIProviderOptions, }); + if (googleOptions == null && providerOptionsName !== 'google') { + googleOptions = await parseProviderOptions({ + provider: 'google', + providerOptions, + schema: googleGenerativeAIProviderOptions, + }); + } + // Add warning if Vertex rag tools are used with a non-Vertex Google provider if ( tools?.some( @@ -182,13 +193,14 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { labels: googleOptions?.labels, }, warnings: [...warnings, ...toolWarnings], + providerOptionsName, }; } async doGenerate( options: Parameters[0], ): Promise>> { - const { args, warnings } = await this.getArgs(options); + const { args, warnings, providerOptionsName } = await this.getArgs(options); const mergedHeaders = combineHeaders( await resolve(this.config.headers), @@ -253,7 +265,11 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { type: part.thought === true ? 'reasoning' : 'text', text: part.text, providerMetadata: part.thoughtSignature - ? { google: { thoughtSignature: part.thoughtSignature } } + ? { + [providerOptionsName]: { + thoughtSignature: part.thoughtSignature, + }, + } : undefined, }); } else if ('functionCall' in part) { @@ -263,7 +279,11 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { toolName: part.functionCall.name, input: JSON.stringify(part.functionCall.args), providerMetadata: part.thoughtSignature - ? { google: { thoughtSignature: part.thoughtSignature } } + ? { + [providerOptionsName]: { + thoughtSignature: part.thoughtSignature, + }, + } : undefined, }); } else if ('inlineData' in part) { @@ -272,7 +292,11 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { data: part.inlineData.data, mediaType: part.inlineData.mimeType, providerMetadata: part.thoughtSignature - ? { google: { thoughtSignature: part.thoughtSignature } } + ? { + [providerOptionsName]: { + thoughtSignature: part.thoughtSignature, + }, + } : undefined, }); } @@ -296,7 +320,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { usage: convertGoogleGenerativeAIUsage(usageMetadata), warnings, providerMetadata: { - google: { + [providerOptionsName]: { promptFeedback: response.promptFeedback ?? null, groundingMetadata: candidate.groundingMetadata ?? null, urlContextMetadata: candidate.urlContextMetadata ?? null, @@ -316,7 +340,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { async doStream( options: Parameters[0], ): Promise>> { - const { args, warnings } = await this.getArgs(options); + const { args, warnings, providerOptionsName } = await this.getArgs(options); const headers = combineHeaders( await resolve(this.config.headers), @@ -466,7 +490,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { id: currentReasoningBlockId, providerMetadata: part.thoughtSignature ? { - google: { + [providerOptionsName]: { thoughtSignature: part.thoughtSignature, }, } @@ -480,7 +504,9 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { delta: part.text, providerMetadata: part.thoughtSignature ? { - google: { thoughtSignature: part.thoughtSignature }, + [providerOptionsName]: { + thoughtSignature: part.thoughtSignature, + }, } : undefined, }); @@ -502,7 +528,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { id: currentTextBlockId, providerMetadata: part.thoughtSignature ? { - google: { + [providerOptionsName]: { thoughtSignature: part.thoughtSignature, }, } @@ -516,7 +542,9 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { delta: part.text, providerMetadata: part.thoughtSignature ? { - google: { thoughtSignature: part.thoughtSignature }, + [providerOptionsName]: { + thoughtSignature: part.thoughtSignature, + }, } : undefined, }); @@ -534,6 +562,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { const toolCallDeltas = getToolCallsFromParts({ parts: content.parts, generateId, + providerOptionsName, }); if (toolCallDeltas != null) { @@ -578,7 +607,7 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { }); providerMetadata = { - google: { + [providerOptionsName]: { promptFeedback: value.promptFeedback ?? null, groundingMetadata: candidate.groundingMetadata ?? null, urlContextMetadata: candidate.urlContextMetadata ?? null, @@ -586,7 +615,12 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { }, }; if (usageMetadata != null) { - providerMetadata.google.usageMetadata = usageMetadata; + ( + providerMetadata[providerOptionsName] as Record< + string, + unknown + > + ).usageMetadata = usageMetadata; } } }, @@ -624,9 +658,11 @@ export class GoogleGenerativeAILanguageModel implements LanguageModelV3 { function getToolCallsFromParts({ parts, generateId, + providerOptionsName, }: { parts: ContentSchema['parts']; generateId: () => string; + providerOptionsName: string; }) { const functionCallParts = parts?.filter( part => 'functionCall' in part, @@ -645,7 +681,11 @@ function getToolCallsFromParts({ toolName: part.functionCall.name, args: JSON.stringify(part.functionCall.args), providerMetadata: part.thoughtSignature - ? { google: { thoughtSignature: part.thoughtSignature } } + ? { + [providerOptionsName]: { + thoughtSignature: part.thoughtSignature, + }, + } : undefined, })); }