diff --git a/common/api-review/vertexai.api.md b/common/api-review/vertexai.api.md index ef7d0b4f300..ca8bcf17662 100644 --- a/common/api-review/vertexai.api.md +++ b/common/api-review/vertexai.api.md @@ -42,7 +42,8 @@ export class BooleanSchema extends Schema { // @public export class ChatSession { // Warning: (ae-forgotten-export) The symbol "ApiSettings" needs to be exported by the entry point index.d.ts - constructor(apiSettings: ApiSettings, model: string, params?: StartChatParams | undefined, requestOptions?: RequestOptions | undefined); + // Warning: (ae-forgotten-export) The symbol "ChromeAdapter" needs to be exported by the entry point index.d.ts + constructor(apiSettings: ApiSettings, model: string, chromeAdapter: ChromeAdapter, params?: StartChatParams | undefined, requestOptions?: RequestOptions | undefined); getHistory(): Promise; // (undocumented) model: string; @@ -324,7 +325,7 @@ export interface GenerativeContentBlob { // @public export class GenerativeModel extends VertexAIModel { - constructor(vertexAI: VertexAI, modelParams: ModelParams, requestOptions?: RequestOptions); + constructor(vertexAI: VertexAI, modelParams: ModelParams, chromeAdapter: ChromeAdapter, requestOptions?: RequestOptions); countTokens(request: CountTokensRequest | string | Array): Promise; static DEFAULT_HYBRID_IN_CLOUD_MODEL: string; generateContent(request: GenerateContentRequest | string | Array): Promise; diff --git a/packages/vertexai/src/api.ts b/packages/vertexai/src/api.ts index 64bf4231edb..512edfe033b 100644 --- a/packages/vertexai/src/api.ts +++ b/packages/vertexai/src/api.ts @@ -30,6 +30,7 @@ import { } from './types'; import { VertexAIError } from './errors'; import { VertexAIModel, GenerativeModel, ImagenModel } from './models'; +import { ChromeAdapter } from './methods/chrome-adapter'; export { ChatSession } from './methods/chat-session'; export * from './requests/schema-builder'; @@ -91,7 +92,16 @@ export function getGenerativeModel( `Must provide a model name. Example: getGenerativeModel({ model: 'my-model-name' })` ); } - return new GenerativeModel(vertexAI, inCloudParams, requestOptions); + return new GenerativeModel( + vertexAI, + inCloudParams, + new ChromeAdapter( + window.ai as AI, + hybridParams.mode, + hybridParams.onDeviceParams + ), + requestOptions + ); } /** diff --git a/packages/vertexai/src/methods/chat-session.test.ts b/packages/vertexai/src/methods/chat-session.test.ts index bd389a3d778..64f77f740f0 100644 --- a/packages/vertexai/src/methods/chat-session.test.ts +++ b/packages/vertexai/src/methods/chat-session.test.ts @@ -23,6 +23,7 @@ import * as generateContentMethods from './generate-content'; import { GenerateContentStreamResult } from '../types'; import { ChatSession } from './chat-session'; import { ApiSettings } from '../types/internal'; +import { ChromeAdapter } from './chrome-adapter'; use(sinonChai); use(chaiAsPromised); @@ -44,7 +45,11 @@ describe('ChatSession', () => { generateContentMethods, 'generateContent' ).rejects('generateContent failed'); - const chatSession = new ChatSession(fakeApiSettings, 'a-model'); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + new ChromeAdapter() + ); await expect(chatSession.sendMessage('hello')).to.be.rejected; expect(generateContentStub).to.be.calledWith( fakeApiSettings, @@ -61,7 +66,11 @@ describe('ChatSession', () => { generateContentMethods, 'generateContentStream' ).rejects('generateContentStream failed'); - const chatSession = new ChatSession(fakeApiSettings, 'a-model'); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + new ChromeAdapter() + ); await expect(chatSession.sendMessageStream('hello')).to.be.rejected; expect(generateContentStreamStub).to.be.calledWith( fakeApiSettings, @@ -80,7 +89,11 @@ describe('ChatSession', () => { generateContentMethods, 'generateContentStream' ).resolves({} as unknown as GenerateContentStreamResult); - const chatSession = new ChatSession(fakeApiSettings, 'a-model'); + const chatSession = new ChatSession( + fakeApiSettings, + 'a-model', + new ChromeAdapter() + ); await chatSession.sendMessageStream('hello'); expect(generateContentStreamStub).to.be.calledWith( fakeApiSettings, diff --git a/packages/vertexai/src/methods/chat-session.ts b/packages/vertexai/src/methods/chat-session.ts index dd22b29a7c8..55c7700c156 100644 --- a/packages/vertexai/src/methods/chat-session.ts +++ b/packages/vertexai/src/methods/chat-session.ts @@ -30,6 +30,7 @@ import { validateChatHistory } from './chat-session-helpers'; import { generateContent, generateContentStream } from './generate-content'; import { ApiSettings } from '../types/internal'; import { logger } from '../logger'; +import { ChromeAdapter } from './chrome-adapter'; /** * Do not log a message for this error. @@ -50,6 +51,7 @@ export class ChatSession { constructor( apiSettings: ApiSettings, public model: string, + private chromeAdapter: ChromeAdapter, public params?: StartChatParams, public requestOptions?: RequestOptions ) { @@ -95,6 +97,7 @@ export class ChatSession { this._apiSettings, this.model, generateContentRequest, + this.chromeAdapter, this.requestOptions ) ) diff --git a/packages/vertexai/src/methods/chrome-adapter.ts b/packages/vertexai/src/methods/chrome-adapter.ts new file mode 100644 index 00000000000..f66e02f6711 --- /dev/null +++ b/packages/vertexai/src/methods/chrome-adapter.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { + EnhancedGenerateContentResponse, + GenerateContentRequest, + InferenceMode +} from '../types'; + +/** + * Defines an inference "backend" that uses Chrome's on-device model, + * and encapsulates logic for detecting when on-device is possible. + */ +export class ChromeAdapter { + constructor( + private aiProvider?: AI, + private mode?: InferenceMode, + private onDeviceParams?: AILanguageModelCreateOptionsWithSystemPrompt + ) {} + // eslint-disable-next-line @typescript-eslint/no-unused-vars + async isAvailable(request: GenerateContentRequest): Promise { + return false; + } + async generateContentOnDevice( + // eslint-disable-next-line @typescript-eslint/no-unused-vars + request: GenerateContentRequest + ): Promise { + return { + text: () => '', + functionCalls: () => undefined + }; + } +} diff --git a/packages/vertexai/src/methods/generate-content.test.ts b/packages/vertexai/src/methods/generate-content.test.ts index 426bd5176db..ec825421709 100644 --- a/packages/vertexai/src/methods/generate-content.test.ts +++ b/packages/vertexai/src/methods/generate-content.test.ts @@ -30,6 +30,7 @@ import { } from '../types'; import { ApiSettings } from '../types/internal'; import { Task } from '../requests/request'; +import { ChromeAdapter } from './chrome-adapter'; use(sinonChai); use(chaiAsPromised); @@ -69,7 +70,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text()).to.include('Mountain View, California'); expect(makeRequestStub).to.be.calledWith( @@ -91,7 +93,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text()).to.include('Use Freshly Ground Coffee'); expect(result.response.text()).to.include('30 minutes of brewing'); @@ -113,7 +116,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.usageMetadata?.totalTokenCount).to.equal(1913); expect(result.response.usageMetadata?.candidatesTokenCount).to.equal(76); @@ -145,7 +149,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text()).to.include( 'Some information cited from an external source' @@ -171,7 +176,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text).to.throw('SAFETY'); expect(makeRequestStub).to.be.calledWith( @@ -192,7 +198,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text).to.throw('SAFETY'); expect(makeRequestStub).to.be.calledWith( @@ -211,7 +218,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text()).to.equal(''); expect(makeRequestStub).to.be.calledWith( @@ -232,7 +240,8 @@ describe('generateContent()', () => { const result = await generateContent( fakeApiSettings, 'model', - fakeRequestParams + fakeRequestParams, + new ChromeAdapter() ); expect(result.response.text()).to.include('Some text'); expect(makeRequestStub).to.be.calledWith( @@ -251,7 +260,12 @@ describe('generateContent()', () => { json: mockResponse.json } as Response); await expect( - generateContent(fakeApiSettings, 'model', fakeRequestParams) + generateContent( + fakeApiSettings, + 'model', + fakeRequestParams, + new ChromeAdapter() + ) ).to.be.rejectedWith(/400.*invalid argument/); expect(mockFetch).to.be.called; }); @@ -265,10 +279,36 @@ describe('generateContent()', () => { json: mockResponse.json } as Response); await expect( - generateContent(fakeApiSettings, 'model', fakeRequestParams) + generateContent( + fakeApiSettings, + 'model', + fakeRequestParams, + new ChromeAdapter() + ) ).to.be.rejectedWith( /firebasevertexai\.googleapis[\s\S]*my-project[\s\S]*api-not-enabled/ ); expect(mockFetch).to.be.called; }); + it('on-device', async () => { + const expectedText = 'hi'; + const chromeAdapter = new ChromeAdapter(); + const mockIsAvailable = stub(chromeAdapter, 'isAvailable').resolves(true); + const mockGenerateContent = stub( + chromeAdapter, + 'generateContentOnDevice' + ).resolves({ + text: () => expectedText, + functionCalls: () => undefined + }); + const result = await generateContent( + fakeApiSettings, + 'model', + fakeRequestParams, + chromeAdapter + ); + expect(result.response.text()).to.equal(expectedText); + expect(mockIsAvailable).to.be.called; + expect(mockGenerateContent).to.be.calledWith(fakeRequestParams); + }); }); diff --git a/packages/vertexai/src/methods/generate-content.ts b/packages/vertexai/src/methods/generate-content.ts index 0944b38016a..63745c47fae 100644 --- a/packages/vertexai/src/methods/generate-content.ts +++ b/packages/vertexai/src/methods/generate-content.ts @@ -16,6 +16,7 @@ */ import { + EnhancedGenerateContentResponse, GenerateContentRequest, GenerateContentResponse, GenerateContentResult, @@ -26,6 +27,7 @@ import { Task, makeRequest } from '../requests/request'; import { createEnhancedContentResponse } from '../requests/response-helpers'; import { processStream } from '../requests/stream-reader'; import { ApiSettings } from '../types/internal'; +import { ChromeAdapter } from './chrome-adapter'; export async function generateContentStream( apiSettings: ApiSettings, @@ -44,12 +46,12 @@ export async function generateContentStream( return processStream(response); } -export async function generateContent( +async function generateContentOnCloud( apiSettings: ApiSettings, model: string, params: GenerateContentRequest, requestOptions?: RequestOptions -): Promise { +): Promise { const response = await makeRequest( model, Task.GENERATE_CONTENT, @@ -60,6 +62,27 @@ export async function generateContent( ); const responseJson: GenerateContentResponse = await response.json(); const enhancedResponse = createEnhancedContentResponse(responseJson); + return enhancedResponse; +} + +export async function generateContent( + apiSettings: ApiSettings, + model: string, + params: GenerateContentRequest, + chromeAdapter: ChromeAdapter, + requestOptions?: RequestOptions +): Promise { + let enhancedResponse; + if (await chromeAdapter.isAvailable(params)) { + enhancedResponse = await chromeAdapter.generateContentOnDevice(params); + } else { + enhancedResponse = await generateContentOnCloud( + apiSettings, + model, + params, + requestOptions + ); + } return { response: enhancedResponse }; diff --git a/packages/vertexai/src/models/generative-model.test.ts b/packages/vertexai/src/models/generative-model.test.ts index 26dff4e04c6..32ddcca41d1 100644 --- a/packages/vertexai/src/models/generative-model.test.ts +++ b/packages/vertexai/src/models/generative-model.test.ts @@ -21,6 +21,7 @@ import * as request from '../requests/request'; import { match, restore, stub } from 'sinon'; import { getMockResponse } from '../../test-utils/mock-response'; import sinonChai from 'sinon-chai'; +import { ChromeAdapter } from '../methods/chrome-adapter'; use(sinonChai); @@ -39,21 +40,27 @@ const fakeVertexAI: VertexAI = { describe('GenerativeModel', () => { it('passes params through to generateContent', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [ - { - functionDeclarations: [ - { - name: 'myfunc', - description: 'mydesc' - } - ] - } - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } - }); + const genModel = new GenerativeModel( + fakeVertexAI, + { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'myfunc', + description: 'mydesc' + } + ] + } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } + }, + new ChromeAdapter() + ); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( FunctionCallingMode.NONE @@ -83,10 +90,14 @@ describe('GenerativeModel', () => { restore(); }); it('passes text-only systemInstruction through to generateContent', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - systemInstruction: 'be friendly' - }); + const genModel = new GenerativeModel( + fakeVertexAI, + { + model: 'my-model', + systemInstruction: 'be friendly' + }, + new ChromeAdapter() + ); expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); const mockResponse = getMockResponse( 'unary-success-basic-reply-short.json' @@ -108,21 +119,27 @@ describe('GenerativeModel', () => { restore(); }); it('generateContent overrides model values', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [ - { - functionDeclarations: [ - { - name: 'myfunc', - description: 'mydesc' - } - ] - } - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } - }); + const genModel = new GenerativeModel( + fakeVertexAI, + { + model: 'my-model', + tools: [ + { + functionDeclarations: [ + { + name: 'myfunc', + description: 'mydesc' + } + ] + } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } + }, + new ChromeAdapter() + ); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( FunctionCallingMode.NONE @@ -163,14 +180,20 @@ describe('GenerativeModel', () => { restore(); }); it('passes params through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [ - { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } - }); + const genModel = new GenerativeModel( + fakeVertexAI, + { + model: 'my-model', + tools: [ + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } + }, + new ChromeAdapter() + ); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( FunctionCallingMode.NONE @@ -200,10 +223,14 @@ describe('GenerativeModel', () => { restore(); }); it('passes text-only systemInstruction through to chat.sendMessage', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - systemInstruction: 'be friendly' - }); + const genModel = new GenerativeModel( + fakeVertexAI, + { + model: 'my-model', + systemInstruction: 'be friendly' + }, + new ChromeAdapter() + ); expect(genModel.systemInstruction?.parts[0].text).to.equal('be friendly'); const mockResponse = getMockResponse( 'unary-success-basic-reply-short.json' @@ -225,14 +252,20 @@ describe('GenerativeModel', () => { restore(); }); it('startChat overrides model values', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { - model: 'my-model', - tools: [ - { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } - ], - toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } }, - systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } - }); + const genModel = new GenerativeModel( + fakeVertexAI, + { + model: 'my-model', + tools: [ + { functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] } + ], + toolConfig: { + functionCallingConfig: { mode: FunctionCallingMode.NONE } + }, + systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] } + }, + new ChromeAdapter() + ); expect(genModel.tools?.length).to.equal(1); expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal( FunctionCallingMode.NONE @@ -276,7 +309,11 @@ describe('GenerativeModel', () => { restore(); }); it('calls countTokens', async () => { - const genModel = new GenerativeModel(fakeVertexAI, { model: 'my-model' }); + const genModel = new GenerativeModel( + fakeVertexAI, + { model: 'my-model' }, + new ChromeAdapter() + ); const mockResponse = getMockResponse('unary-success-total-tokens.json'); const makeRequestStub = stub(request, 'makeRequest').resolves( mockResponse as Response diff --git a/packages/vertexai/src/models/generative-model.ts b/packages/vertexai/src/models/generative-model.ts index 76e76bfb8d2..f8f699e43eb 100644 --- a/packages/vertexai/src/models/generative-model.ts +++ b/packages/vertexai/src/models/generative-model.ts @@ -43,6 +43,7 @@ import { } from '../requests/request-helpers'; import { VertexAI } from '../public-types'; import { VertexAIModel } from './vertexai-model'; +import { ChromeAdapter } from '../methods/chrome-adapter'; /** * Class for generative model APIs. @@ -63,6 +64,7 @@ export class GenerativeModel extends VertexAIModel { constructor( vertexAI: VertexAI, modelParams: ModelParams, + private chromeAdapter: ChromeAdapter, requestOptions?: RequestOptions ) { super(vertexAI, modelParams.model); @@ -95,6 +97,7 @@ export class GenerativeModel extends VertexAIModel { systemInstruction: this.systemInstruction, ...formattedParams }, + this.chromeAdapter, this.requestOptions ); } @@ -132,6 +135,7 @@ export class GenerativeModel extends VertexAIModel { return new ChatSession( this._apiSettings, this.model, + this.chromeAdapter, { tools: this.tools, toolConfig: this.toolConfig,