Skip to content

Commit

Permalink
fix(ai/core): correctly allow plugin initialization from runtime cont…
Browse files Browse the repository at this point in the history
…ext (#1678)
  • Loading branch information
pavelgj authored Jan 29, 2025
1 parent 241eca1 commit f771771
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 53 deletions.
12 changes: 11 additions & 1 deletion js/core/src/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ const runtimeContextAslKey = 'core.action.runtimeContext';
* Checks whether the caller is currently in the runtime context of an action.
*/
export function isInRuntimeContext(registry: Registry) {
return !!registry.asyncStore.getStore(runtimeContextAslKey);
return registry.asyncStore.getStore(runtimeContextAslKey) === 'runtime';
}

/**
Expand All @@ -521,3 +521,13 @@ export function isInRuntimeContext(registry: Registry) {
export function runInActionRuntimeContext<R>(registry: Registry, fn: () => R) {
return registry.asyncStore.run(runtimeContextAslKey, 'runtime', fn);
}

/**
* Execute the provided function outside the action runtime context.
*/
export function runOutsideActionRuntimeContext<R>(
registry: Registry,
fn: () => R
) {
return registry.asyncStore.run(runtimeContextAslKey, 'outside', fn);
}
48 changes: 0 additions & 48 deletions js/core/src/plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,52 +44,4 @@ export interface InitializedPlugin {
telemetry?: any;
}

type PluginInit = (
...args: any[]
) => InitializedPlugin | void | Promise<InitializedPlugin | void>;

export type Plugin<T extends any[]> = (...args: T) => PluginProvider;

/**
* Defines a Genkit plugin.
*/
export function genkitPlugin<T extends PluginInit>(
pluginName: string,
initFn: T
): Plugin<Parameters<T>> {
return (...args: Parameters<T>) => ({
name: pluginName,
initializer: async () => {
const initializedPlugin = (await initFn(...args)) || {};
validatePluginActions(pluginName, initializedPlugin);
return initializedPlugin;
},
});
}

function validatePluginActions(pluginName: string, plugin?: InitializedPlugin) {
if (!plugin) {
return;
}

plugin.models?.forEach((model) => validateNaming(pluginName, model));
plugin.retrievers?.forEach((retriever) =>
validateNaming(pluginName, retriever)
);
plugin.embedders?.forEach((embedder) => validateNaming(pluginName, embedder));
plugin.indexers?.forEach((indexer) => validateNaming(pluginName, indexer));
plugin.evaluators?.forEach((evaluator) =>
validateNaming(pluginName, evaluator)
);
}

function validateNaming(
pluginName: string,
action: Action<z.ZodTypeAny, z.ZodTypeAny>
) {
const nameParts = action.__action.name.split('/');
if (nameParts[0] !== pluginName) {
const err = `Plugin name ${pluginName} not found in action name ${action.__action.name}. Action names must follow the pattern {pluginName}/{actionName}`;
throw new Error(err);
}
}
6 changes: 4 additions & 2 deletions js/core/src/registry.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import { Dotprompt } from 'dotprompt';
import { AsyncLocalStorage } from 'node:async_hooks';
import * as z from 'zod';
import { Action } from './action.js';
import { Action, runOutsideActionRuntimeContext } from './action.js';
import { GenkitError } from './error.js';
import { logger } from './logging.js';
import { PluginProvider } from './plugin.js';
Expand Down Expand Up @@ -229,7 +229,9 @@ export class Registry {
*/
async initializePlugin(name: string) {
if (this.pluginsByName[name]) {
return await this.pluginsByName[name].initializer();
return await runOutsideActionRuntimeContext(this, () =>
this.pluginsByName[name].initializer()
);
}
}

Expand Down
2 changes: 1 addition & 1 deletion js/core/tests/action_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ describe('action', () => {
registry,
{ name: 'child', actionType: 'custom' },
async (_, { context }) => {
return `hi ${context.auth.email}`;
return `hi ${context?.auth?.email}`;
}
);
const parent = defineAction(
Expand Down
32 changes: 31 additions & 1 deletion js/core/tests/registry_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@

import * as assert from 'assert';
import { beforeEach, describe, it } from 'node:test';
import { action } from '../src/action.js';
import {
action,
defineAction,
runInActionRuntimeContext,
} from '../src/action.js';
import { Registry } from '../src/registry.js';

describe('registry class', () => {
Expand Down Expand Up @@ -103,6 +107,32 @@ describe('registry class', () => {
});
});

it('should allow plugin initialization from runtime context', async () => {
let fooInitialized = false;
registry.registerPluginProvider('foo', {
name: 'foo',
async initializer() {
defineAction(
registry,
{
actionType: 'model',
name: 'foo/something',
},
async () => null
);
fooInitialized = true;
return {};
},
});

const action = await runInActionRuntimeContext(registry, () =>
registry.lookupAction('/model/foo/something')
);

assert.ok(action);
assert.ok(fooInitialized);
});

it('returns all registered actions, including parent', async () => {
const child = Registry.withParent(registry);

Expand Down

0 comments on commit f771771

Please sign in to comment.