diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 30e1ac7dd..b04f5c99b 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -512,7 +512,7 @@ const runtimeContextAslKey = 'core.action.runtimeContext'; * Checks whether the caller is currently in the runtime context of an action. */ export function isInRuntimeContext(registry: Registry) { - return !!registry.asyncStore.getStore(runtimeContextAslKey); + return registry.asyncStore.getStore(runtimeContextAslKey) === 'runtime'; } /** @@ -521,3 +521,13 @@ export function isInRuntimeContext(registry: Registry) { export function runInActionRuntimeContext(registry: Registry, fn: () => R) { return registry.asyncStore.run(runtimeContextAslKey, 'runtime', fn); } + +/** + * Execute the provided function outside the action runtime context. + */ +export function runOutsideActionRuntimeContext( + registry: Registry, + fn: () => R +) { + return registry.asyncStore.run(runtimeContextAslKey, 'outside', fn); +} diff --git a/js/core/src/plugin.ts b/js/core/src/plugin.ts index ae5bfe340..5f5b9d1d2 100644 --- a/js/core/src/plugin.ts +++ b/js/core/src/plugin.ts @@ -44,52 +44,4 @@ export interface InitializedPlugin { telemetry?: any; } -type PluginInit = ( - ...args: any[] -) => InitializedPlugin | void | Promise; - export type Plugin = (...args: T) => PluginProvider; - -/** - * Defines a Genkit plugin. - */ -export function genkitPlugin( - pluginName: string, - initFn: T -): Plugin> { - return (...args: Parameters) => ({ - name: pluginName, - initializer: async () => { - const initializedPlugin = (await initFn(...args)) || {}; - validatePluginActions(pluginName, initializedPlugin); - return initializedPlugin; - }, - }); -} - -function validatePluginActions(pluginName: string, plugin?: InitializedPlugin) { - if (!plugin) { - return; - } - - plugin.models?.forEach((model) => validateNaming(pluginName, model)); - plugin.retrievers?.forEach((retriever) => - validateNaming(pluginName, retriever) - ); - plugin.embedders?.forEach((embedder) => validateNaming(pluginName, embedder)); - plugin.indexers?.forEach((indexer) => validateNaming(pluginName, indexer)); - plugin.evaluators?.forEach((evaluator) => - validateNaming(pluginName, evaluator) - ); -} - -function validateNaming( - pluginName: string, - action: Action -) { - const nameParts = action.__action.name.split('/'); - if (nameParts[0] !== pluginName) { - const err = `Plugin name ${pluginName} not found in action name ${action.__action.name}. Action names must follow the pattern {pluginName}/{actionName}`; - throw new Error(err); - } -} diff --git a/js/core/src/registry.ts b/js/core/src/registry.ts index 1768cb0cf..96d95ada7 100644 --- a/js/core/src/registry.ts +++ b/js/core/src/registry.ts @@ -17,7 +17,7 @@ import { Dotprompt } from 'dotprompt'; import { AsyncLocalStorage } from 'node:async_hooks'; import * as z from 'zod'; -import { Action } from './action.js'; +import { Action, runOutsideActionRuntimeContext } from './action.js'; import { GenkitError } from './error.js'; import { logger } from './logging.js'; import { PluginProvider } from './plugin.js'; @@ -229,7 +229,9 @@ export class Registry { */ async initializePlugin(name: string) { if (this.pluginsByName[name]) { - return await this.pluginsByName[name].initializer(); + return await runOutsideActionRuntimeContext(this, () => + this.pluginsByName[name].initializer() + ); } } diff --git a/js/core/tests/action_test.ts b/js/core/tests/action_test.ts index 76f5c1d84..bdaa419c1 100644 --- a/js/core/tests/action_test.ts +++ b/js/core/tests/action_test.ts @@ -150,7 +150,7 @@ describe('action', () => { registry, { name: 'child', actionType: 'custom' }, async (_, { context }) => { - return `hi ${context.auth.email}`; + return `hi ${context?.auth?.email}`; } ); const parent = defineAction( diff --git a/js/core/tests/registry_test.ts b/js/core/tests/registry_test.ts index 357620ca6..24d15e4ba 100644 --- a/js/core/tests/registry_test.ts +++ b/js/core/tests/registry_test.ts @@ -16,7 +16,11 @@ import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; -import { action } from '../src/action.js'; +import { + action, + defineAction, + runInActionRuntimeContext, +} from '../src/action.js'; import { Registry } from '../src/registry.js'; describe('registry class', () => { @@ -103,6 +107,32 @@ describe('registry class', () => { }); }); + it('should allow plugin initialization from runtime context', async () => { + let fooInitialized = false; + registry.registerPluginProvider('foo', { + name: 'foo', + async initializer() { + defineAction( + registry, + { + actionType: 'model', + name: 'foo/something', + }, + async () => null + ); + fooInitialized = true; + return {}; + }, + }); + + const action = await runInActionRuntimeContext(registry, () => + registry.lookupAction('/model/foo/something') + ); + + assert.ok(action); + assert.ok(fooInitialized); + }); + it('returns all registered actions, including parent', async () => { const child = Registry.withParent(registry);