Skip to content

Commit

Permalink
feat(js/plugins/googleai): adds gemini() function for unspecified mod…
Browse files Browse the repository at this point in the history
…el support (#1467)

Co-authored-by: Michael Bleigh <[email protected]>
  • Loading branch information
pavelgj and mbleigh authored Dec 12, 2024
1 parent a2bf8a3 commit d68ed2f
Show file tree
Hide file tree
Showing 6 changed files with 247 additions and 14 deletions.
2 changes: 1 addition & 1 deletion js/plugins/googleai/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"build:clean": "rimraf ./lib",
"build": "npm-run-all build:clean check compile",
"build:watch": "tsup-node --watch",
"test": "tsx --test ./tests/**/*_test.ts"
"test": "tsx --test ./tests/*_test.ts ./tests/**/*_test.ts"
},
"repository": {
"type": "git",
Expand Down
77 changes: 69 additions & 8 deletions js/plugins/googleai/src/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ export const GeminiConfigSchema = GenerationCommonConfigSchema.extend({
})
.optional(),
});
export type GeminiConfig = z.infer<typeof GeminiConfigSchema>;

export const gemini10Pro = modelRef({
name: 'googleai/gemini-1.0-pro',
Expand Down Expand Up @@ -189,13 +190,73 @@ export const SUPPORTED_V15_MODELS = {
'gemini-2.0-flash-exp': gemini20FlashExp,
};

export const SUPPORTED_GEMINI_MODELS: Record<
string,
ModelReference<typeof GeminiConfigSchema>
> = {
export const GENERIC_GEMINI_MODEL = modelRef({
name: 'googleai/gemini',
configSchema: GeminiConfigSchema,
info: {
label: 'Google Gemini',
supports: {
multiturn: true,
media: true,
tools: true,
systemRole: true,
},
},
});

export const SUPPORTED_GEMINI_MODELS = {
...SUPPORTED_V1_MODELS,
...SUPPORTED_V15_MODELS,
};
} as const;

function longestMatchingPrefix(version: string, potentialMatches: string[]) {
return potentialMatches
.filter((p) => version.startsWith(p))
.reduce(
(longest, current) =>
current.length > longest.length ? current : longest,
''
);
}
export type GeminiVersionString =
| keyof typeof SUPPORTED_GEMINI_MODELS
| (string & {});

export function gemini(
version: GeminiVersionString,
options: GeminiConfig = {}
): ModelReference<typeof GeminiConfigSchema> {
const nearestModel = nearestGeminiModelRef(version);
return modelRef({
name: `googleai/${version}`,
config: options,
configSchema: GeminiConfigSchema,
info: {
...nearestModel.info,
// If exact suffix match for a known model, use its label, otherwise create a new label
label: nearestModel.name.endsWith(version)
? nearestModel.info?.label
: `Google AI - ${version}`,
},
});
}

function nearestGeminiModelRef(
version: GeminiVersionString,
options: GeminiConfig = {}
): ModelReference<typeof GeminiConfigSchema> {
const matchingKey = longestMatchingPrefix(
version,
Object.keys(SUPPORTED_GEMINI_MODELS)
);
if (matchingKey) {
return SUPPORTED_GEMINI_MODELS[matchingKey].withConfig({
...options,
version,
});
}
return GENERIC_GEMINI_MODEL.withConfig({ ...options, version });
}

function toGeminiRole(
role: MessageData['role'],
Expand Down Expand Up @@ -501,7 +562,7 @@ export function defineGoogleAIModel(
apiVersion?: string,
baseUrl?: string,
info?: ModelInfo,
defaultConfig?: z.infer<typeof GeminiConfigSchema>
defaultConfig?: GeminiConfig
): ModelAction {
if (!apiKey) {
apiKey = process.env.GOOGLE_GENAI_API_KEY || process.env.GOOGLE_API_KEY;
Expand All @@ -519,7 +580,7 @@ export function defineGoogleAIModel(
const model: ModelReference<z.ZodTypeAny> =
SUPPORTED_GEMINI_MODELS[name] ??
modelRef({
name,
name: `googleai/${apiModelName}`,
info: {
label: `Google AI - ${apiModelName}`,
supports: {
Expand Down Expand Up @@ -648,7 +709,7 @@ export function defineGoogleAIModel(
let chatRequest: StartChatParams = {
systemInstruction,
generationConfig,
tools,
tools: tools.length ? tools : undefined,
toolConfig,
history: messages
.slice(0, -1)
Expand Down
30 changes: 29 additions & 1 deletion js/plugins/googleai/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,27 @@
* limitations under the License.
*/

import { Genkit } from 'genkit';
import { Genkit, ModelReference } from 'genkit';
import { GenkitPlugin, genkitPlugin } from 'genkit/plugin';
import {
SUPPORTED_MODELS as EMBEDDER_MODELS,
defineGoogleAIEmbedder,
textEmbeddingGecko001,
} from './embedder.js';
import {
GeminiConfigSchema,
SUPPORTED_V15_MODELS,
SUPPORTED_V1_MODELS,
defineGoogleAIModel,
gemini,
gemini10Pro,
gemini15Flash,
gemini15Flash8b,
gemini15Pro,
gemini20FlashExp,
} from './gemini.js';
export {
gemini,
gemini10Pro,
gemini15Flash,
gemini15Flash8b,
Expand All @@ -44,6 +47,7 @@ export interface PluginOptions {
apiKey?: string;
apiVersion?: string | string[];
baseUrl?: string;
models?: (ModelReference<typeof GeminiConfigSchema> | string)[];
}

export function googleAI(options?: PluginOptions): GenkitPlugin {
Expand All @@ -57,6 +61,7 @@ export function googleAI(options?: PluginOptions): GenkitPlugin {
apiVersions = [options?.apiVersion];
}
}

if (apiVersions.includes('v1beta')) {
Object.keys(SUPPORTED_V15_MODELS).forEach((name) =>
defineGoogleAIModel(
Expand Down Expand Up @@ -91,6 +96,29 @@ export function googleAI(options?: PluginOptions): GenkitPlugin {
defineGoogleAIEmbedder(ai, name, { apiKey: options?.apiKey })
);
}

if (options?.models) {
for (const modelOrRef of options?.models) {
const modelName =
typeof modelOrRef === 'string'
? modelOrRef
: // strip out the `googleai/` prefix
modelOrRef.name.split('/')[1];
const modelRef =
typeof modelOrRef === 'string' ? gemini(modelOrRef) : modelOrRef;
defineGoogleAIModel(
ai,
modelName,
options?.apiKey,
undefined,
options?.baseUrl,
{
...modelRef.info,
label: `Google AI - ${modelName}`,
}
);
}
}
});
}

Expand Down
136 changes: 134 additions & 2 deletions js/plugins/googleai/tests/gemini_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,21 @@
*/

import { GenerateContentCandidate } from '@google/generative-ai';
import { MessageData } from 'genkit/model';
import { genkit } from 'genkit';
import { MessageData, ModelInfo } from 'genkit/model';
import assert from 'node:assert';
import { describe, it } from 'node:test';
import { afterEach, beforeEach, describe, it } from 'node:test';
import {
GENERIC_GEMINI_MODEL,
cleanSchema,
fromGeminiCandidate,
gemini,
gemini15Flash,
gemini15Pro,
toGeminiMessage,
toGeminiSystemInstruction,
} from '../src/gemini.js';
import { googleAI } from '../src/index.js';

describe('toGeminiMessages', () => {
const testCases = [
Expand Down Expand Up @@ -378,3 +384,129 @@ describe('cleanSchema', () => {
});
});
});

describe('plugin', () => {
it('should init the plugin without requiring the api key', async () => {
const ai = genkit({
plugins: [googleAI()],
});

assert.ok(ai);
});

describe('plugin', () => {
beforeEach(() => {
process.env.GOOGLE_GENAI_API_KEY = 'testApiKey';
});
afterEach(() => {
delete process.env.GOOGLE_GENAI_API_KEY;
});

it('should pre-register a few flagship models', async () => {
const ai = genkit({
plugins: [googleAI()],
});

assert.ok(await ai.registry.lookupAction(`/model/${gemini15Flash.name}`));
assert.ok(await ai.registry.lookupAction(`/model/${gemini15Pro.name}`));
});

it('allow referencing models using `gemini` helper', async () => {
const ai = genkit({
plugins: [googleAI()],
});

const pro = await ai.registry.lookupAction(
`/model/${gemini('gemini-1.5-pro').name}`
);
assert.ok(pro);
assert.strictEqual(pro.__action.name, 'googleai/gemini-1.5-pro');
const flash = await ai.registry.lookupAction(
`/model/${gemini('gemini-1.5-flash').name}`
);
assert.ok(flash);
assert.strictEqual(flash.__action.name, 'googleai/gemini-1.5-flash');
});

it('references explicitly registered models', async () => {
const flash002Ref = gemini('gemini-1.5-flash-002');
const ai = genkit({
plugins: [
googleAI({
models: ['gemini-1.5-pro-002', flash002Ref, 'gemini-4.0-banana'],
}),
],
});

const pro002Ref = gemini('gemini-1.5-pro-002');
assert.strictEqual(pro002Ref.name, 'googleai/gemini-1.5-pro-002');
assertEqualModelInfo(
pro002Ref.info!,
'Google AI - gemini-1.5-pro-002',
gemini15Pro.info!
);
const pro002 = await ai.registry.lookupAction(`/model/${pro002Ref.name}`);
assert.ok(pro002);
assert.strictEqual(pro002.__action.name, 'googleai/gemini-1.5-pro-002');
assertEqualModelInfo(
pro002.__action.metadata?.model,
'Google AI - gemini-1.5-pro-002',
gemini15Pro.info!
);

assert.strictEqual(flash002Ref.name, 'googleai/gemini-1.5-flash-002');
assertEqualModelInfo(
flash002Ref.info!,
'Google AI - gemini-1.5-flash-002',
gemini15Flash.info!
);
const flash002 = await ai.registry.lookupAction(
`/model/${flash002Ref.name}`
);
assert.ok(flash002);
assert.strictEqual(
flash002.__action.name,
'googleai/gemini-1.5-flash-002'
);
assertEqualModelInfo(
flash002.__action.metadata?.model,
'Google AI - gemini-1.5-flash-002',
gemini15Flash.info!
);

const bananaRef = gemini('gemini-4.0-banana');
assert.strictEqual(bananaRef.name, 'googleai/gemini-4.0-banana');
assertEqualModelInfo(
bananaRef.info!,
'Google AI - gemini-4.0-banana',
GENERIC_GEMINI_MODEL.info! // <---- generic model fallback
);
const banana = await ai.registry.lookupAction(`/model/${bananaRef.name}`);
assert.ok(banana);
assert.strictEqual(banana.__action.name, 'googleai/gemini-4.0-banana');
assertEqualModelInfo(
banana.__action.metadata?.model,
'Google AI - gemini-4.0-banana',
GENERIC_GEMINI_MODEL.info! // <---- generic model fallback
);

// this one is not registered
const flash003Ref = gemini('gemini-1.5-flash-003');
assert.strictEqual(flash003Ref.name, 'googleai/gemini-1.5-flash-003');
const flash003 = await ai.registry.lookupAction(
`/model/${flash003Ref.name}`
);
assert.ok(flash003 === undefined);
});
});
});

function assertEqualModelInfo(
modelAction: ModelInfo,
expectedLabel: string,
expectedInfo: ModelInfo
) {
assert.strictEqual(modelAction.label, expectedLabel);
assert.deepStrictEqual(modelAction.supports, expectedInfo.supports);
assert.deepStrictEqual(modelAction.versions, expectedInfo.versions);
}
9 changes: 9 additions & 0 deletions js/pnpm-lock.yaml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit d68ed2f

Please sign in to comment.