From d68ed2f6d59fa5c4b769f301cae0b8c876e7e2b5 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 12 Dec 2024 17:22:42 -0500 Subject: [PATCH] feat(js/plugins/googleai): adds gemini() function for unspecified model support (#1467) Co-authored-by: Michael Bleigh --- js/plugins/googleai/package.json | 2 +- js/plugins/googleai/src/gemini.ts | 77 +++++++++++-- js/plugins/googleai/src/index.ts | 30 ++++- js/plugins/googleai/tests/gemini_test.ts | 136 ++++++++++++++++++++++- js/pnpm-lock.yaml | 9 ++ js/testapps/flow-simple-ai/package.json | 7 +- 6 files changed, 247 insertions(+), 14 deletions(-) diff --git a/js/plugins/googleai/package.json b/js/plugins/googleai/package.json index 2d8738e11..31513f012 100644 --- a/js/plugins/googleai/package.json +++ b/js/plugins/googleai/package.json @@ -21,7 +21,7 @@ "build:clean": "rimraf ./lib", "build": "npm-run-all build:clean check compile", "build:watch": "tsup-node --watch", - "test": "tsx --test ./tests/**/*_test.ts" + "test": "tsx --test ./tests/*_test.ts ./tests/**/*_test.ts" }, "repository": { "type": "git", diff --git a/js/plugins/googleai/src/gemini.ts b/js/plugins/googleai/src/gemini.ts index 67b1932d6..ccc381063 100644 --- a/js/plugins/googleai/src/gemini.ts +++ b/js/plugins/googleai/src/gemini.ts @@ -92,6 +92,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({ }) .optional(), }); +export type GeminiConfig = z.infer; export const gemini10Pro = modelRef({ name: 'googleai/gemini-1.0-pro', @@ -189,13 +190,73 @@ export const SUPPORTED_V15_MODELS = { 'gemini-2.0-flash-exp': gemini20FlashExp, }; -export const SUPPORTED_GEMINI_MODELS: Record< - string, - ModelReference -> = { +export const GENERIC_GEMINI_MODEL = modelRef({ + name: 'googleai/gemini', + configSchema: GeminiConfigSchema, + info: { + label: 'Google Gemini', + supports: { + multiturn: true, + media: true, + tools: true, + systemRole: true, + }, + }, +}); + +export const SUPPORTED_GEMINI_MODELS = { ...SUPPORTED_V1_MODELS, ...SUPPORTED_V15_MODELS, -}; +} as const; + +function longestMatchingPrefix(version: string, potentialMatches: string[]) { + return potentialMatches + .filter((p) => version.startsWith(p)) + .reduce( + (longest, current) => + current.length > longest.length ? current : longest, + '' + ); +} +export type GeminiVersionString = + | keyof typeof SUPPORTED_GEMINI_MODELS + | (string & {}); + +export function gemini( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const nearestModel = nearestGeminiModelRef(version); + return modelRef({ + name: `googleai/${version}`, + config: options, + configSchema: GeminiConfigSchema, + info: { + ...nearestModel.info, + // If exact suffix match for a known model, use its label, otherwise create a new label + label: nearestModel.name.endsWith(version) + ? nearestModel.info?.label + : `Google AI - ${version}`, + }, + }); +} + +function nearestGeminiModelRef( + version: GeminiVersionString, + options: GeminiConfig = {} +): ModelReference { + const matchingKey = longestMatchingPrefix( + version, + Object.keys(SUPPORTED_GEMINI_MODELS) + ); + if (matchingKey) { + return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({ + ...options, + version, + }); + } + return GENERIC_GEMINI_MODEL.withConfig({ ...options, version }); +} function toGeminiRole( role: MessageData['role'], @@ -501,7 +562,7 @@ export function defineGoogleAIModel( apiVersion?: string, baseUrl?: string, info?: ModelInfo, - defaultConfig?: z.infer + defaultConfig?: GeminiConfig ): ModelAction { if (!apiKey) { apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY; @@ -519,7 +580,7 @@ export function defineGoogleAIModel( const model: ModelReference = SUPPORTED_GEMINI_MODELS[name] ?? modelRef({ - name, + name: `googleai/${apiModelName}`, info: { label: `Google AI - ${apiModelName}`, supports: { @@ -648,7 +709,7 @@ export function defineGoogleAIModel( let chatRequest: StartChatParams = { systemInstruction, generationConfig, - tools, + tools: tools.length ? tools : undefined, toolConfig, history: messages .slice(0, -1) diff --git a/js/plugins/googleai/src/index.ts b/js/plugins/googleai/src/index.ts index be1de65f1..b9e6ea582 100644 --- a/js/plugins/googleai/src/index.ts +++ b/js/plugins/googleai/src/index.ts @@ -14,7 +14,7 @@ * limitations under the License. */ -import { Genkit } from 'genkit'; +import { Genkit, ModelReference } from 'genkit'; import { GenkitPlugin, genkitPlugin } from 'genkit/plugin'; import { SUPPORTED_MODELS as EMBEDDER_MODELS, @@ -22,9 +22,11 @@ import { textEmbeddingGecko001, } from './embedder.js'; import { + GeminiConfigSchema, SUPPORTED_V15_MODELS, SUPPORTED_V1_MODELS, defineGoogleAIModel, + gemini, gemini10Pro, gemini15Flash, gemini15Flash8b, @@ -32,6 +34,7 @@ import { gemini20FlashExp, } from './gemini.js'; export { + gemini, gemini10Pro, gemini15Flash, gemini15Flash8b, @@ -44,6 +47,7 @@ export interface PluginOptions { apiKey?: string; apiVersion?: string | string[]; baseUrl?: string; + models?: (ModelReference | string)[]; } export function googleAI(options?: PluginOptions): GenkitPlugin { @@ -57,6 +61,7 @@ export function googleAI(options?: PluginOptions): GenkitPlugin { apiVersions = [options?.apiVersion]; } } + if (apiVersions.includes('v1beta')) { Object.keys(SUPPORTED_V15_MODELS).forEach((name) => defineGoogleAIModel( @@ -91,6 +96,29 @@ export function googleAI(options?: PluginOptions): GenkitPlugin { defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey }) ); } + + if (options?.models) { + for (const modelOrRef of options?.models) { + const modelName = + typeof modelOrRef === 'string' + ? modelOrRef + : // strip out the `googleai/` prefix + modelOrRef.name.split('/')[1]; + const modelRef = + typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef; + defineGoogleAIModel( + ai, + modelName, + options?.apiKey, + undefined, + options?.baseUrl, + { + ...modelRef.info, + label: `Google AI - ${modelName}`, + } + ); + } + } }); } diff --git a/js/plugins/googleai/tests/gemini_test.ts b/js/plugins/googleai/tests/gemini_test.ts index 8f5549a81..6ae5da39f 100644 --- a/js/plugins/googleai/tests/gemini_test.ts +++ b/js/plugins/googleai/tests/gemini_test.ts @@ -15,15 +15,21 @@ */ import { GenerateContentCandidate } from '@google/generative-ai'; -import { MessageData } from 'genkit/model'; +import { genkit } from 'genkit'; +import { MessageData, ModelInfo } from 'genkit/model'; import assert from 'node:assert'; -import { describe, it } from 'node:test'; +import { afterEach, beforeEach, describe, it } from 'node:test'; import { + GENERIC_GEMINI_MODEL, cleanSchema, fromGeminiCandidate, + gemini, + gemini15Flash, + gemini15Pro, toGeminiMessage, toGeminiSystemInstruction, } from '../src/gemini.js'; +import { googleAI } from '../src/index.js'; describe('toGeminiMessages', () => { const testCases = [ @@ -378,3 +384,129 @@ describe('cleanSchema', () => { }); }); }); + +describe('plugin', () => { + it('should init the plugin without requiring the api key', async () => { + const ai = genkit({ + plugins: [googleAI()], + }); + + assert.ok(ai); + }); + + describe('plugin', () => { + beforeEach(() => { + process.env.GOOGLE_GENAI_API_KEY = 'testApiKey'; + }); + afterEach(() => { + delete process.env.GOOGLE_GENAI_API_KEY; + }); + + it('should pre-register a few flagship models', async () => { + const ai = genkit({ + plugins: [googleAI()], + }); + + assert.ok(await ai.registry.lookupAction(`/model/${gemini15Flash.name}`)); + assert.ok(await ai.registry.lookupAction(`/model/${gemini15Pro.name}`)); + }); + + it('allow referencing models using `gemini` helper', async () => { + const ai = genkit({ + plugins: [googleAI()], + }); + + const pro = await ai.registry.lookupAction( + `/model/${gemini('gemini-1.5-pro').name}` + ); + assert.ok(pro); + assert.strictEqual(pro.__action.name, 'googleai/gemini-1.5-pro'); + const flash = await ai.registry.lookupAction( + `/model/${gemini('gemini-1.5-flash').name}` + ); + assert.ok(flash); + assert.strictEqual(flash.__action.name, 'googleai/gemini-1.5-flash'); + }); + + it('references explicitly registered models', async () => { + const flash002Ref = gemini('gemini-1.5-flash-002'); + const ai = genkit({ + plugins: [ + googleAI({ + models: ['gemini-1.5-pro-002', flash002Ref, 'gemini-4.0-banana'], + }), + ], + }); + + const pro002Ref = gemini('gemini-1.5-pro-002'); + assert.strictEqual(pro002Ref.name, 'googleai/gemini-1.5-pro-002'); + assertEqualModelInfo( + pro002Ref.info!, + 'Google AI - gemini-1.5-pro-002', + gemini15Pro.info! + ); + const pro002 = await ai.registry.lookupAction(`/model/${pro002Ref.name}`); + assert.ok(pro002); + assert.strictEqual(pro002.__action.name, 'googleai/gemini-1.5-pro-002'); + assertEqualModelInfo( + pro002.__action.metadata?.model, + 'Google AI - gemini-1.5-pro-002', + gemini15Pro.info! + ); + + assert.strictEqual(flash002Ref.name, 'googleai/gemini-1.5-flash-002'); + assertEqualModelInfo( + flash002Ref.info!, + 'Google AI - gemini-1.5-flash-002', + gemini15Flash.info! + ); + const flash002 = await ai.registry.lookupAction( + `/model/${flash002Ref.name}` + ); + assert.ok(flash002); + assert.strictEqual( + flash002.__action.name, + 'googleai/gemini-1.5-flash-002' + ); + assertEqualModelInfo( + flash002.__action.metadata?.model, + 'Google AI - gemini-1.5-flash-002', + gemini15Flash.info! + ); + + const bananaRef = gemini('gemini-4.0-banana'); + assert.strictEqual(bananaRef.name, 'googleai/gemini-4.0-banana'); + assertEqualModelInfo( + bananaRef.info!, + 'Google AI - gemini-4.0-banana', + GENERIC_GEMINI_MODEL.info! // <---- generic model fallback + ); + const banana = await ai.registry.lookupAction(`/model/${bananaRef.name}`); + assert.ok(banana); + assert.strictEqual(banana.__action.name, 'googleai/gemini-4.0-banana'); + assertEqualModelInfo( + banana.__action.metadata?.model, + 'Google AI - gemini-4.0-banana', + GENERIC_GEMINI_MODEL.info! // <---- generic model fallback + ); + + // this one is not registered + const flash003Ref = gemini('gemini-1.5-flash-003'); + assert.strictEqual(flash003Ref.name, 'googleai/gemini-1.5-flash-003'); + const flash003 = await ai.registry.lookupAction( + `/model/${flash003Ref.name}` + ); + assert.ok(flash003 === undefined); + }); + }); +}); + +function assertEqualModelInfo( + modelAction: ModelInfo, + expectedLabel: string, + expectedInfo: ModelInfo +) { + assert.strictEqual(modelAction.label, expectedLabel); + assert.deepStrictEqual(modelAction.supports, expectedInfo.supports); + assert.deepStrictEqual(modelAction.versions, expectedInfo.versions); +} diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 64b2cbe6d..f2adc7290 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -1172,6 +1172,12 @@ importers: '@opentelemetry/sdk-trace-base': specifier: ^1.25.0 version: 1.25.1(@opentelemetry/api@1.9.0) + body-parser: + specifier: ^1.20.3 + version: 1.20.3 + express: + specifier: ^4.21.0 + version: 4.21.0 firebase-admin: specifier: '>=12.2' version: 12.3.1(encoding@0.1.13) @@ -1185,6 +1191,9 @@ importers: rimraf: specifier: ^6.0.1 version: 6.0.1 + tsx: + specifier: ^4.19.2 + version: 4.19.2 typescript: specifier: ^5.3.3 version: 5.4.5 diff --git a/js/testapps/flow-simple-ai/package.json b/js/testapps/flow-simple-ai/package.json index fd0f02e9e..6ca06aad8 100644 --- a/js/testapps/flow-simple-ai/package.json +++ b/js/testapps/flow-simple-ai/package.json @@ -4,7 +4,7 @@ "description": "", "main": "lib/index.js", "scripts": { - "start": "node lib/index.js", + "start": "pnpm exec genkit start -- pnpm exec tsx --watch src/index.ts", "compile": "tsc", "build": "pnpm build:clean && pnpm compile", "build:clean": "rimraf ./lib", @@ -15,18 +15,21 @@ "author": "", "license": "ISC", "dependencies": { - "genkit": "workspace:*", "@genkit-ai/firebase": "workspace:*", "@genkit-ai/google-cloud": "workspace:*", "@genkit-ai/googleai": "workspace:*", "@genkit-ai/vertexai": "workspace:*", "@google/generative-ai": "^0.15.0", "@opentelemetry/sdk-trace-base": "^1.25.0", + "body-parser": "^1.20.3", + "express": "^4.21.0", "firebase-admin": ">=12.2", + "genkit": "workspace:*", "partial-json": "^0.1.7" }, "devDependencies": { "rimraf": "^6.0.1", + "tsx": "^4.19.2", "typescript": "^5.3.3" } }