diff --git a/js/ai/src/generate.ts b/js/ai/src/generate.ts index feee4433c..13da1c66c 100755 --- a/js/ai/src/generate.ts +++ b/js/ai/src/generate.ts @@ -22,7 +22,6 @@ import { runWithContext, runWithStreamingCallback, sentinelNoopStreamingCallback, - stripUndefinedProps, z, } from '@genkit-ai/core'; import { Channel } from '@genkit-ai/core/async'; @@ -46,6 +45,7 @@ import { ModelArgument, ModelMiddleware, Part, + ToolRequestPart, ToolResponsePart, resolveModel, } from './model.js'; @@ -75,7 +75,16 @@ export interface ResumeOptions { * Tools have a `.reply` helper method to construct a reply ToolResponse and validate * the data against its schema. Call `myTool.reply(interruptToolRequest, yourReplyData)`. */ - reply: ToolResponsePart | ToolResponsePart[]; + reply?: ToolResponsePart | ToolResponsePart[]; + /** + * restart will run a tool again with additionally supplied metadata passed through as + * a `resumed` option in the second argument. This allows for scenarios like conditionally + * requesting confirmation of an LLM's tool request. + * + * Tools have a `.restart` helper method to construct a restart ToolRequest. Call + * `myTool.restart(interruptToolRequest, resumeMetadata)`. + */ + restart?: ToolRequestPart | ToolRequestPart[]; /** Additional metadata to annotate the created tool message with in the "resume" key. */ metadata?: Record; } @@ -141,53 +150,6 @@ export interface GenerateOptions< context?: ActionContext; } -/** Amends message history to handle `resume` arguments. Returns the amended history. */ -async function applyResumeOption( - options: GenerateOptions, - messages: MessageData[] -): Promise { - if (!options.resume) return messages; - if ( - messages.at(-1)?.role !== 'model' || - !messages - .at(-1) - ?.content.find((p) => p.toolRequest && p.metadata?.interrupt) - ) { - throw new GenkitError({ - status: 'FAILED_PRECONDITION', - message: `Cannot 'resume' generation unless the previous message is a model message with at least one interrupt.`, - }); - } - const lastModelMessage = messages.at(-1)!; - const toolRequests = lastModelMessage.content.filter((p) => !!p.toolRequest); - - const pendingResponses: ToolResponsePart[] = toolRequests - .filter((t) => !!t.metadata?.pendingOutput) - .map((t) => - stripUndefinedProps({ - toolResponse: { - name: t.toolRequest!.name, - ref: t.toolRequest!.ref, - output: t.metadata!.pendingOutput, - }, - metadata: { source: 'pending' }, - }) - ) as ToolResponsePart[]; - - const reply = Array.isArray(options.resume.reply) - ? options.resume.reply - : [options.resume.reply]; - - const message: MessageData = { - role: 'tool', - content: [...pendingResponses, ...reply], - metadata: { - resume: options.resume.metadata || true, - }, - }; - return [...messages, message]; -} - export async function toGenerateRequest( registry: Registry, options: GenerateOptions @@ -202,8 +164,6 @@ export async function toGenerateRequest( if (options.messages) { messages.push(...options.messages.map((m) => Message.parseData(m))); } - // resuming from interrupts occurs after message history but before user prompt - messages = await applyResumeOption(options, messages); if (options.prompt) { messages.push({ role: 'user', @@ -216,6 +176,19 @@ export async function toGenerateRequest( message: 'at least one message is required in generate request', }); } + if ( + options.resume && + !( + messages.at(-1)?.role === 'model' && + messages.at(-1)?.content.find((p) => !!p.toolRequest) + ) + ) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `Last message must be a 'model' role with at least one tool request to 'resume' generation.`, + detail: messages.at(-1), + }); + } let tools: Action[] | undefined; if (options.tools) { tools = await resolveTools(registry, options.tools); @@ -386,6 +359,12 @@ export async function generate< format: resolvedOptions.output.format, jsonSchema: resolvedSchema, }, + // coerce reply and restart into arrays for the action schema + resume: resolvedOptions.resume && { + reply: [resolvedOptions.resume.reply || []].flat(), + restart: [resolvedOptions.resume.restart || []].flat(), + metadata: resolvedOptions.resume.metadata, + }, returnToolRequests: resolvedOptions.returnToolRequests, maxTurns: resolvedOptions.maxTurns, }; diff --git a/js/ai/src/generate/action.ts b/js/ai/src/generate/action.ts index 41d84006e..0b694835c 100644 --- a/js/ai/src/generate/action.ts +++ b/js/ai/src/generate/action.ts @@ -17,6 +17,7 @@ import { Action, defineAction, + GenkitError, getStreamingCallback, runWithStreamingCallback, stripUndefinedProps, @@ -25,7 +26,7 @@ import { import { logger } from '@genkit-ai/core/logging'; import { Registry } from '@genkit-ai/core/registry'; import { toJsonSchema } from '@genkit-ai/core/schema'; -import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing'; +import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing'; import { injectInstructions, resolveFormat, @@ -52,12 +53,13 @@ import { ModelMiddleware, ModelRequest, Part, - Role, resolveModel, + Role, } from '../model.js'; -import { ToolAction, resolveTools, toToolDefinition } from '../tool.js'; +import { resolveTools, ToolAction, toToolDefinition } from '../tool.js'; import { assertValidToolNames, + resolveResumeOption, resolveToolRequests, } from './resolve-tool-requests.js'; @@ -225,6 +227,23 @@ async function generate( // check to make sure we don't have overlapping tool names *before* generation await assertValidToolNames(tools); + const { revisedRequest, interruptedResponse } = await resolveResumeOption( + registry, + rawRequest + ); + // NOTE: in the future we should make it possible to interrupt a restart, but + // at the moment it's too complicated because it's not clear how to return a + // response that amends history but doesn't generate a new message, so we throw + if (interruptedResponse) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: + 'One or more tools triggered an interrupt during a restarted execution.', + detail: { message: interruptedResponse.message }, + }); + } + rawRequest = revisedRequest!; + const request = await actionToGenerateRequest( rawRequest, tools, diff --git a/js/ai/src/generate/resolve-tool-requests.ts b/js/ai/src/generate/resolve-tool-requests.ts index c8ead6e28..220350f70 100644 --- a/js/ai/src/generate/resolve-tool-requests.ts +++ b/js/ai/src/generate/resolve-tool-requests.ts @@ -19,11 +19,20 @@ import { logger } from '@genkit-ai/core/logging'; import { Registry } from '@genkit-ai/core/registry'; import { GenerateActionOptions, + GenerateResponseData, MessageData, + Part, + ToolRequestPart, ToolResponsePart, } from '../model.js'; import { isPromptAction } from '../prompt.js'; -import { ToolAction, ToolInterruptError, resolveTools } from '../tool.js'; +import { + ToolAction, + ToolInterruptError, + ToolRunOptions, + isToolRequest, + resolveTools, +} from '../tool.js'; export function toToolMap(tools: ToolAction[]): Record { assertValidToolNames(tools); @@ -52,6 +61,87 @@ export function assertValidToolNames(tools: ToolAction[]) { } } +function toRunOptions(part: ToolRequestPart): ToolRunOptions { + const out: ToolRunOptions = { metadata: part.metadata }; + if (part.metadata?.resumed) out.resumed = part.metadata.resumed; + return out; +} + +export function toPendingOutput( + part: ToolRequestPart, + response: ToolResponsePart +): ToolRequestPart { + return { + ...part, + metadata: { + ...part.metadata, + pendingOutput: response.toolResponse.output, + }, + }; +} + +export async function resolveToolRequest( + rawRequest: GenerateActionOptions, + part: ToolRequestPart, + toolMap: Record, + runOptions?: ToolRunOptions +): Promise<{ + response?: ToolResponsePart; + interrupt?: ToolRequestPart; + preamble?: GenerateActionOptions; +}> { + const tool = toolMap[part.toolRequest.name]; + if (!tool) { + throw new GenkitError({ + status: 'NOT_FOUND', + message: `Tool ${part.toolRequest.name} not found`, + detail: { request: rawRequest }, + }); + } + + // if it's a prompt action, go ahead and render the preamble + if (isPromptAction(tool)) { + const preamble = await tool(part.toolRequest.input); + const response = { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: `transferred to ${part.toolRequest.name}`, + }, + }; + + return { preamble, response }; + } + + // otherwise, execute the tool and catch interrupts + try { + const output = await tool(part.toolRequest.input, toRunOptions(part)); + const response = stripUndefinedProps({ + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output, + }, + }); + + return { response }; + } catch (e) { + if (e instanceof ToolInterruptError) { + logger.debug( + `tool '${toolMap[part.toolRequest?.name].__action.name}' triggered an interrupt${e.metadata ? `: ${JSON.stringify(e.metadata)}` : ''}` + ); + const interrupt = { + toolRequest: part.toolRequest, + metadata: { ...part.metadata, interrupt: e.metadata || true }, + }; + + return { interrupt }; + } + + throw e; + } +} + /** * resolveToolRequests is responsible for executing the tools requested by the model for a single turn. it * returns either a toolMessage to append or a revisedModelMessage when an interrupt occurs, and a transferPreamble @@ -81,66 +171,36 @@ export async function resolveToolRequests( revisedModelMessage.content.map(async (part, i) => { if (!part.toolRequest) return; // skip non-tool-request parts - const tool = toolMap[part.toolRequest.name]; - if (!tool) { - throw new GenkitError({ - status: 'NOT_FOUND', - message: `Tool ${part.toolRequest.name} not found`, - detail: { request: rawRequest }, - }); - } + const { preamble, response, interrupt } = await resolveToolRequest( + rawRequest, + part as ToolRequestPart, + toolMap + ); - // if it's a prompt action, go ahead and render the preamble - if (isPromptAction(tool)) { - if (transferPreamble) + if (preamble) { + if (transferPreamble) { throw new GenkitError({ status: 'INVALID_ARGUMENT', message: `Model attempted to transfer to multiple prompt tools.`, }); - transferPreamble = await tool(part.toolRequest.input); - responseParts.push({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output: `transferred to ${part.toolRequest.name}`, - }, - }); - return; + } + + transferPreamble = preamble; } - // otherwise, execute the tool and catch interrupts - try { - const output = await tool(part.toolRequest.input, {}); - const responsePart = stripUndefinedProps({ - toolResponse: { - name: part.toolRequest.name, - ref: part.toolRequest.ref, - output, - }, - }); - - revisedModelMessage.content.splice(i, 1, { - ...part, - metadata: { - ...part.metadata, - pendingOutput: responsePart.toolResponse.output, - }, - }); - responseParts.push(responsePart); - } catch (e) { - if (e instanceof ToolInterruptError) { - logger.debug( - `tool '${toolMap[part.toolRequest?.name].__action.name}' triggered an interrupt${e.metadata ? `: ${JSON.stringify(e.metadata)}` : ''}` - ); - revisedModelMessage.content.splice(i, 1, { - toolRequest: part.toolRequest, - metadata: { ...part.metadata, interrupt: e.metadata || true }, - }); - hasInterrupts = true; - return; - } + // this happens for preamble or normal tools + if (response) { + responseParts.push(response!); + revisedModelMessage.content.splice( + i, + 1, + toPendingOutput(part, response) + ); + } - throw e; + if (interrupt) { + revisedModelMessage.content.splice(i, 1, interrupt); + hasInterrupts = true; } }) ); @@ -154,3 +214,232 @@ export async function resolveToolRequests( transferPreamble, }; } + +function findCorrespondingToolRequest( + parts: Part[], + part: ToolRequestPart | ToolResponsePart +): ToolRequestPart | undefined { + const name = part.toolRequest?.name || part.toolResponse?.name; + const ref = part.toolRequest?.ref || part.toolResponse?.ref; + + return parts.find( + (p) => p.toolRequest?.name === name && p.toolRequest?.ref === ref + ) as ToolRequestPart | undefined; +} + +function findCorrespondingToolResponse( + parts: Part[], + part: ToolRequestPart | ToolResponsePart +): ToolResponsePart | undefined { + const name = part.toolRequest?.name || part.toolResponse?.name; + const ref = part.toolRequest?.ref || part.toolResponse?.ref; + + return parts.find( + (p) => p.toolResponse?.name === name && p.toolResponse?.ref === ref + ) as ToolResponsePart | undefined; +} + +async function resolveResumedToolRequest( + rawRequest: GenerateActionOptions, + part: ToolRequestPart, + toolMap: Record +): Promise<{ + toolRequest?: ToolRequestPart; + toolResponse?: ToolResponsePart; + interrupt?: ToolRequestPart; +}> { + if (part.metadata?.pendingOutput) { + const { pendingOutput, ...metadata } = part.metadata; + const toolResponse = { + toolResponse: { + name: part.toolRequest.name, + ref: part.toolRequest.ref, + output: pendingOutput, + }, + metadata: { ...metadata, source: 'pending' }, + }; + + // strip pendingOutput from metadata when returning + return stripUndefinedProps({ + toolResponse, + toolRequest: { ...part, metadata }, + }); + } + + // if there's a corresponding reply, append it to toolResponses + const replyResponse = findCorrespondingToolResponse( + rawRequest.resume?.reply || [], + part + ); + if (replyResponse) { + const toolResponse = replyResponse; + + // remove the 'interrupt' but leave a 'resolvedInterrupt' + const { interrupt, ...metadata } = part.metadata || {}; + return stripUndefinedProps({ + toolResponse, + toolRequest: { + ...part, + metadata: { ...metadata, resolvedInterrupt: interrupt }, + }, + }); + } + + // if there's a corresponding restart, execute then add to toolResponses + const restartRequest = findCorrespondingToolRequest( + rawRequest.resume?.restart || [], + part + ); + if (restartRequest) { + const { response, interrupt, preamble } = await resolveToolRequest( + rawRequest, + restartRequest, + toolMap + ); + + if (preamble) { + throw new GenkitError({ + status: 'INTERNAL', + message: `Prompt tool '${restartRequest.toolRequest.name}' executed inside 'restart' resolution. This should never happen.`, + }); + } + + // if there's a new interrupt, return it + if (interrupt) return { interrupt }; + + if (response) { + const toolResponse = response; + + // remove the 'interrupt' but leave a 'resolvedInterrupt' + const { interrupt, ...metadata } = part.metadata || {}; + return stripUndefinedProps({ + toolResponse, + toolRequest: { + ...part, + metadata: { ...metadata, resolvedInterrupt: interrupt }, + }, + }); + } + } + + throw new GenkitError({ + status: 'INVALID_ARGUMENT', + message: `Unresolved tool request '${part.toolRequest.name}${part.toolRequest.ref ? `#${part.toolRequest.ref}` : ''} was not handled by the 'resume' argument. You must supply replies or restarts for all interrupted tool requests.'`, + }); +} + +/** Amends message history to handle `resume` arguments. Returns the amended history. */ +export async function resolveResumeOption( + registry: Registry, + rawRequest: GenerateActionOptions +): Promise<{ + revisedRequest?: GenerateActionOptions; + interruptedResponse?: GenerateResponseData; +}> { + if (!rawRequest.resume) return { revisedRequest: rawRequest }; // no-op if no resume option + console.log('RESOLVE RESUME OPTION:', rawRequest.resume); + const toolMap = toToolMap(await resolveTools(registry, rawRequest.tools)); + + const messages = rawRequest.messages; + const lastMessage = messages.at(-1); + + if ( + !lastMessage || + lastMessage.role !== 'model' || + !lastMessage.content.find((p) => p.toolRequest && p.metadata?.interrupt) + ) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `Cannot 'resume' generation unless the previous message is a model message with at least one interrupt.`, + }); + } + + const toolResponses: ToolResponsePart[] = []; + let interrupted = false; + + lastMessage.content = await Promise.all( + lastMessage.content.map(async (part) => { + if (!isToolRequest(part)) return part; + const resolved = await resolveResumedToolRequest( + rawRequest, + part, + toolMap + ); + console.log('RESOLVED TOOL', part.toolRequest.name, 'TO', resolved); + if (resolved.interrupt) { + interrupted = true; + return resolved.interrupt; + } + + toolResponses.push(resolved.toolResponse!); + return resolved.toolRequest!; + }) + ); + + if (interrupted) { + // TODO: figure out how to make this trigger an interrupt response. + return { + interruptedResponse: { + finishReason: 'interrupted', + finishMessage: + 'One or more tools triggered interrupts while resuming generation. The model was not called.', + message: lastMessage, + }, + }; + } + + const numToolRequests = lastMessage.content.filter( + (p) => !!p.toolRequest + ).length; + if (toolResponses.length !== numToolRequests) { + throw new GenkitError({ + status: 'FAILED_PRECONDITION', + message: `Expected ${numToolRequests} tool responses but resolved to ${toolResponses.length}.`, + detail: { toolResponses, message: lastMessage }, + }); + } + + const toolMessage: MessageData = { + role: 'tool', + content: toolResponses, + metadata: { + resumed: rawRequest.resume.metadata || true, + }, + }; + + console.log('CONSTRUCTED A TOOL MESSAGE:', toolMessage.content); + return stripUndefinedProps({ + revisedRequest: { + ...rawRequest, + resume: undefined, + messages: [...messages, toolMessage], + }, + }); +} + +export async function resolveRestartedTools( + registry: Registry, + rawRequest: GenerateActionOptions +): Promise { + const toolMap = toToolMap(await resolveTools(registry, rawRequest.tools)); + const lastMessage = rawRequest.messages.at(-1); + if (!lastMessage || lastMessage.role !== 'model') return []; + + const restarts = lastMessage.content.filter( + (p) => p.toolRequest && p.metadata?.resumed + ) as ToolRequestPart[]; + + return await Promise.all( + restarts.map(async (p) => { + const { response, interrupt } = await resolveToolRequest( + rawRequest, + p, + toolMap + ); + + // this means that it interrupted *again* after the restart + if (interrupt) return interrupt; + return toPendingOutput(p, response!); + }) + ); +} diff --git a/js/ai/src/model.ts b/js/ai/src/model.ts index 284f285fd..b2a8884c0 100644 --- a/js/ai/src/model.ts +++ b/js/ai/src/model.ts @@ -711,6 +711,14 @@ export const GenerateActionOptionsSchema = z.object({ jsonSchema: z.any().optional(), }) .optional(), + /** Options for resuming an interrupted generation. */ + resume: z + .object({ + reply: z.array(ToolResponsePartSchema).optional(), + restart: z.array(ToolRequestPartSchema).optional(), + metadata: z.record(z.any()).optional(), + }) + .optional(), /** When true, return tool calls for manual processing instead of automatically resolving them. */ returnToolRequests: z.boolean().optional(), /** Maximum number of tool call iterations that can be performed in a single generate call (default 5). */ diff --git a/js/ai/src/tool.ts b/js/ai/src/tool.ts index a39cc96da..722a62590 100644 --- a/js/ai/src/tool.ts +++ b/js/ai/src/tool.ts @@ -17,6 +17,7 @@ import { Action, ActionContext, + ActionRunOptions, defineAction, JSONSchema7, stripUndefinedProps, @@ -25,7 +26,12 @@ import { import { Registry } from '@genkit-ai/core/registry'; import { parseSchema, toJsonSchema } from '@genkit-ai/core/schema'; import { setCustomMetadataAttributes } from '@genkit-ai/core/tracing'; -import { ToolDefinition, ToolRequestPart, ToolResponsePart } from './model.js'; +import { + Part, + ToolDefinition, + ToolRequestPart, + ToolResponsePart, +} from './model.js'; import { ExecutablePrompt } from './prompt.js'; /** @@ -34,7 +40,7 @@ import { ExecutablePrompt } from './prompt.js'; export type ToolAction< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, -> = Action & { +> = Action & { __action: { metadata: { type: 'tool'; @@ -55,8 +61,37 @@ export type ToolAction< replyData: z.infer, options?: { metadata?: Record } ): ToolResponsePart; + /** + * restart constructs a tool request corresponding to the provided interrupt tool request + * that will then re-trigger the tool after e.g. a user confirms. The `resumedMetadata` + * supplied to this method will be passed to the tool to allow for custom handling of + * restart logic. + * + * @param interrupt The interrupt tool request you want to restart. + * @param resumedMetadata The metadata you want to provide to the tool to aide in reprocessing. Defaults to `true` if none is supplied. + * @param options Additional options for restarting the tool. + */ + restart( + interrupt: ToolRequestPart, + resumedMetadata?: any, + options?: { + /** + * Replace the existing input arguments to the tool with different ones, for example + * if the user revised an action before confirming. When input is replaced, the existing + * tool request will be amended in the message history. + **/ + replaceInput?: z.infer; + } + ): ToolRequestPart; }; +export interface ToolRunOptions extends ActionRunOptions { + /** If resumed is supplied to a tool at runtime, that means that it was previously interrupted and this is a second */ + resumed?: boolean | Record; + /** The metadata from the tool request that triggered this run. */ + metadata?: Record; +} + /** * Configuration for a tool. */ @@ -200,7 +235,7 @@ export interface ToolFnOptions { export type ToolFn = ( input: z.infer, - ctx: ToolFnOptions + ctx: ToolFnOptions & ToolRunOptions ) => Promise>; /** @@ -220,13 +255,13 @@ export function defineTool( actionType: 'tool', metadata: { ...(config.metadata || {}), type: 'tool' }, }, - (i, { context }) => - fn(i, { + (i, runOptions) => { + return fn(i, { + ...runOptions, + context: { ...runOptions.context }, interrupt: interruptTool, - context: { - ...context, - }, - }) + }); + } ); (a as ToolAction).reply = (interrupt, replyData, options) => { parseSchema(replyData, { @@ -244,6 +279,29 @@ export function defineTool( }, }; }; + + (a as ToolAction).restart = (interrupt, resumedMetadata, options) => { + let replaceInput = options?.replaceInput; + if (replaceInput) { + replaceInput = parseSchema(replaceInput, { + schema: config.inputSchema, + jsonSchema: config.inputJsonSchema, + }); + } + return { + toolRequest: stripUndefinedProps({ + name: interrupt.toolRequest.name, + ref: interrupt.toolRequest.ref, + input: replaceInput || interrupt.toolRequest.input, + }), + metadata: stripUndefinedProps({ + ...interrupt.metadata, + resumed: resumedMetadata || true, + // annotate the original input if replacing it + replacedInput: replaceInput ? interrupt.toolRequest.input : undefined, + }), + }; + }; return a as ToolAction; } @@ -264,6 +322,14 @@ export type InterruptConfig< ) => Record | Promise>); }; +export function isToolRequest(part: Part): part is ToolRequestPart { + return !!part.toolRequest; +} + +export function isToolResponse(part: Part): part is ToolResponsePart { + return !!part.toolResponse; +} + export function defineInterrupt( registry: Registry, config: InterruptConfig diff --git a/js/ai/tests/generate/generate_test.ts b/js/ai/tests/generate/generate_test.ts index ac5399a83..085e69070 100644 --- a/js/ai/tests/generate/generate_test.ts +++ b/js/ai/tests/generate/generate_test.ts @@ -279,105 +279,6 @@ describe('toGenerateRequest', () => { }, throws: 'FAILED_PRECONDITION', }, - { - should: 'add pending responses and interrupt replies to a tool message', - prompt: { - messages: [ - { role: 'user', content: [{ text: 'hey' }] }, - { - role: 'model', - content: [ - { - toolRequest: { name: 'p1', ref: '1', input: { one: '1' } }, - metadata: { - pendingOutput: 'done', - }, - }, - { - toolRequest: { name: 'p2', ref: '2', input: { one: '1' } }, - metadata: { - pendingOutput: 'done2', - }, - }, - { - toolRequest: { name: 'i1', ref: '3', input: { one: '1' } }, - metadata: { - interrupt: true, - }, - }, - { - toolRequest: { name: 'i2', ref: '4', input: { one: '1' } }, - metadata: { - interrupt: { sky: 'blue' }, - }, - }, - ], - }, - ], - resume: { - reply: [ - { toolResponse: { name: 'i1', ref: '3', output: 'done3' } }, - { toolResponse: { name: 'i2', ref: '4', output: 'done4' } }, - ], - }, - }, - expectedOutput: { - config: undefined, - docs: undefined, - output: {}, - tools: [], - messages: [ - { role: 'user', content: [{ text: 'hey' }] }, - { - role: 'model', - content: [ - { - toolRequest: { name: 'p1', ref: '1', input: { one: '1' } }, - metadata: { - pendingOutput: 'done', - }, - }, - { - toolRequest: { name: 'p2', ref: '2', input: { one: '1' } }, - metadata: { - pendingOutput: 'done2', - }, - }, - { - toolRequest: { name: 'i1', ref: '3', input: { one: '1' } }, - metadata: { - interrupt: true, - }, - }, - { - toolRequest: { name: 'i2', ref: '4', input: { one: '1' } }, - metadata: { - interrupt: { sky: 'blue' }, - }, - }, - ], - }, - { - role: 'tool', - metadata: { - resume: true, - }, - content: [ - { - toolResponse: { name: 'p1', ref: '1', output: 'done' }, - metadata: { source: 'pending' }, - }, - { - toolResponse: { name: 'p2', ref: '2', output: 'done2' }, - metadata: { source: 'pending' }, - }, - { toolResponse: { name: 'i1', ref: '3', output: 'done3' } }, - { toolResponse: { name: 'i2', ref: '4', output: 'done4' } }, - ], - }, - ], - }, - }, ]; for (const test of testCases) { it(test.should, async () => { diff --git a/js/ai/tests/tool_test.ts b/js/ai/tests/tool_test.ts index e1ac74098..3d184a517 100644 --- a/js/ai/tests/tool_test.ts +++ b/js/ai/tests/tool_test.ts @@ -183,4 +183,81 @@ describe('defineTool', () => { ); }); }); + + describe('.restart()', () => { + it('constructs a ToolRequestPart', () => { + const t = defineTool( + registry, + { name: 'test', description: 'test' }, + async () => {} + ); + assert.deepStrictEqual( + t.restart({ toolRequest: { name: 'test', input: {} } }), + { + toolRequest: { + name: 'test', + input: {}, + }, + metadata: { + resumed: true, + }, + } + ); + }); + + it('includes metadata', () => { + const t = defineTool( + registry, + { name: 'test', description: 'test' }, + async () => {} + ); + assert.deepStrictEqual( + t.restart( + { toolRequest: { name: 'test', input: {} } }, + { extra: 'data' } + ), + { + toolRequest: { + name: 'test', + input: {}, + }, + metadata: { + resumed: { extra: 'data' }, + }, + } + ); + }); + + it('validates schema', () => { + const t = defineTool( + registry, + { name: 'test', description: 'test', inputSchema: z.number() }, + async (input, { interrupt }) => interrupt() + ); + assert.throws( + () => { + t.restart({ toolRequest: { name: 'test', input: {} } }, undefined, { + replaceInput: 'not_a_number' as any, + }); + }, + { name: 'GenkitError', status: 'INVALID_ARGUMENT' } + ); + + assert.deepStrictEqual( + t.restart({ toolRequest: { name: 'test', input: {} } }, undefined, { + replaceInput: 55, + }), + { + toolRequest: { + name: 'test', + input: 55, + }, + metadata: { + resumed: true, + replacedInput: {}, + }, + } + ); + }); + }); }); diff --git a/js/core/src/action.ts b/js/core/src/action.ts index 1bc7362a3..7683d9893 100644 --- a/js/core/src/action.ts +++ b/js/core/src/action.ts @@ -114,10 +114,8 @@ export type Action< I extends z.ZodTypeAny = z.ZodTypeAny, O extends z.ZodTypeAny = z.ZodTypeAny, S extends z.ZodTypeAny = z.ZodTypeAny, -> = (( - input?: z.infer, - options?: ActionRunOptions -) => Promise>) & { + RunOptions extends ActionRunOptions = ActionRunOptions, +> = ((input?: z.infer, options?: RunOptions) => Promise>) & { __action: ActionMetadata; __registry: Registry; run( @@ -313,6 +311,7 @@ export function action< try { const actionFn = () => fn(input, { + ...options, // Context can either be explicitly set, or inherited from the parent action. context: options?.context ?? getContext(registry), sendChunk: options?.onChunk ?? sentinelNoopStreamingCallback, diff --git a/js/genkit/tests/generate_test.ts b/js/genkit/tests/generate_test.ts index 81285862d..4fe0278d3 100644 --- a/js/genkit/tests/generate_test.ts +++ b/js/genkit/tests/generate_test.ts @@ -14,6 +14,7 @@ * limitations under the License. */ +import { MessageData } from '@genkit-ai/ai'; import { z } from '@genkit-ai/core'; import * as assert from 'assert'; import { beforeEach, describe, it } from 'node:test'; @@ -645,6 +646,13 @@ describe('generate', () => { async (input, { interrupt }) => interrupt({ confirm: 'is it a banana?' }) ); + ai.defineTool( + { name: 'resumableTool', description: 'description' }, + async (input, { interrupt, resumed }) => { + if ((resumed as any)?.status === 'ok') return true; + return interrupt(); + } + ); // first response be tools call, the subsequent just text response from agent b. let reqCounter = 0; @@ -669,7 +677,14 @@ describe('generate', () => { toolRequest: { name: 'simpleTool', input: { name: 'foo' }, - ref: 'ref123', + ref: 'ref456', + }, + }, + { + toolRequest: { + name: 'resumableTool', + input: { doIt: true }, + ref: 'ref789', }, }, ] @@ -680,7 +695,7 @@ describe('generate', () => { const response = await ai.generate({ prompt: 'call the tool', - tools: ['interruptingTool', 'simpleTool'], + tools: ['interruptingTool', 'simpleTool', 'resumableTool'], }); assert.strictEqual(reqCounter, 1); @@ -703,12 +718,24 @@ describe('generate', () => { name: 'foo', }, name: 'simpleTool', - ref: 'ref123', + ref: 'ref456', }, metadata: { pendingOutput: 'response: foo', }, }, + { + metadata: { + interrupt: true, + }, + toolRequest: { + name: 'resumableTool', + ref: 'ref789', + input: { + doIt: true, + }, + }, + }, ]); assert.deepStrictEqual(response.message?.toJSON(), { role: 'model', @@ -734,12 +761,24 @@ describe('generate', () => { name: 'foo', }, name: 'simpleTool', - ref: 'ref123', + ref: 'ref456', }, metadata: { pendingOutput: 'response: foo', }, }, + { + metadata: { + interrupt: true, + }, + toolRequest: { + name: 'resumableTool', + ref: 'ref789', + input: { + doIt: true, + }, + }, + }, ], }); assert.deepStrictEqual(pm.lastRequest, { @@ -772,8 +811,148 @@ describe('generate', () => { $schema: 'http://json-schema.org/draft-07/schema#', }, }, + { + description: 'description', + inputSchema: { + $schema: 'http://json-schema.org/draft-07/schema#', + }, + name: 'resumableTool', + outputSchema: { + $schema: 'http://json-schema.org/draft-07/schema#', + }, + }, ], }); }); + + it('can resume generation', { only: true }, async () => { + const interrupter = ai.defineInterrupt({ + name: 'interrupter', + description: 'always interrupts', + }); + const truth = ai.defineTool( + { name: 'truth', description: 'always returns true' }, + async () => true + ); + const resumable = ai.defineTool( + { + name: 'resumable', + description: 'interrupts unless resumed with {status: "ok"}', + }, + async (_input, { interrupt, resumed }) => { + console.log('RESUMABLE TOOL CALLED WITH:', resumed); + if ((resumed as any)?.status === 'ok') return true; + return interrupt(); + } + ); + + const messages: MessageData[] = [ + { role: 'user', content: [{ text: 'hello' }] }, + { + role: 'model', + content: [ + { + toolRequest: { name: 'interrupter', input: {} }, + metadata: { interrupt: true }, + }, + { + toolRequest: { name: 'truth', input: {} }, + metadata: { pendingOutput: true }, + }, + { + toolRequest: { name: 'resumable', input: {} }, + metadata: { interrupt: true }, + }, + ], + }, + ]; + + const response = await ai.generate({ + model: 'echoModel', + messages, + tools: [interrupter, resumable, truth], + resume: { + reply: interrupter.reply( + { + toolRequest: { name: 'interrupter', input: {} }, + metadata: { interrupt: true }, + }, + 23 + ), + restart: resumable.restart( + { + toolRequest: { name: 'resumable', input: {} }, + metadata: { interrupt: true }, + }, + { status: 'ok' } + ), + }, + }); + + const revisedModelMessage = response.messages.at(-3); + const toolMessage = response.messages.at(-2); + + assert.deepStrictEqual( + revisedModelMessage?.content, + [ + { + metadata: { + resolvedInterrupt: true, + }, + toolRequest: { + input: {}, + name: 'interrupter', + }, + }, + { + metadata: {}, + toolRequest: { + input: {}, + name: 'truth', + }, + }, + { + metadata: { + resolvedInterrupt: true, + }, + toolRequest: { + input: {}, + name: 'resumable', + }, + }, + ], + 'resuming amends the model message to resolve interrupts' + ); + assert.deepStrictEqual( + toolMessage?.content, + [ + { + metadata: { + reply: true, + }, + toolResponse: { + name: 'interrupter', + output: 23, + }, + }, + { + metadata: { + source: 'pending', + }, + toolResponse: { + name: 'truth', + output: true, + }, + }, + { + toolResponse: { + name: 'resumable', + output: true, + }, + }, + ], + 'resuming generates a tool message containing all expected responses' + ); + }); }); }); diff --git a/js/plugins/next/src/index.ts b/js/plugins/next/src/index.ts index 4d4bb6194..8cfce3281 100644 --- a/js/plugins/next/src/index.ts +++ b/js/plugins/next/src/index.ts @@ -17,8 +17,10 @@ import type { Action } from '@genkit-ai/core'; import { NextRequest, NextResponse } from 'next/server.js'; -const appRoute = - >(action: A) => +const appRoute: >( + action: A +) => (req: NextRequest) => Promise = + (action) => async (req: NextRequest): Promise => { const { data } = await req.json(); if (req.headers.get('accept') !== 'text/event-stream') { diff --git a/js/pnpm-lock.yaml b/js/pnpm-lock.yaml index 98f8ef086..2ca35d8c5 100644 --- a/js/pnpm-lock.yaml +++ b/js/pnpm-lock.yaml @@ -1373,7 +1373,7 @@ importers: version: link:../../plugins/ollama genkitx-openai: specifier: ^0.10.1 - version: 0.10.1(@genkit-ai/ai@1.0.0-rc.10)(@genkit-ai/core@1.0.0-rc.10) + version: 0.10.1(@genkit-ai/ai@1.0.0-rc.12)(@genkit-ai/core@1.0.0-rc.12) devDependencies: rimraf: specifier: ^6.0.1 @@ -2505,11 +2505,11 @@ packages: '@firebase/webchannel-wrapper@1.0.3': resolution: {integrity: sha512-2xCRM9q9FlzGZCdgDMJwc0gyUkWFtkosy7Xxr6sFgQwn+wMNIWd7xIvYNauU1r64B5L5rsGKy/n9TKJ0aAFeqQ==} - '@genkit-ai/ai@1.0.0-rc.10': - resolution: {integrity: sha512-FIwseXpGFChZAoGXJiSDFSqEUam5dK0SR6EExtukBfZg4BGLz5e0XjzqmPvf8HW3iCWF/3MOZEBx7fVrCW60+g==} + '@genkit-ai/ai@1.0.0-rc.12': + resolution: {integrity: sha512-1hofDfuTEVDfryy7klgeJwCyXOp4wEZ+VPU34Dts9S5PYoN1nb23KAhnz+6dHgM42REM1PI5Zh2Wq8O10BXIww==} - '@genkit-ai/core@1.0.0-rc.10': - resolution: {integrity: sha512-KA+z3oqOsL2w65cMchfmS1ont1gkm0WYOQwzm9cxmyjBbbGP1Ndq3udCIUJM410emITuU4wHEDu1dokDoSaOWA==} + '@genkit-ai/core@1.0.0-rc.12': + resolution: {integrity: sha512-3K7GVXR1vnJu2ANICbojxKpBoxIncuGIvn0NJ7T+s3/IOBQcusmkUyJoc4TdeRsZdX693J+oMUSx+ZE18s28eQ==} '@gerrit0/mini-shiki@1.24.4': resolution: {integrity: sha512-YEHW1QeAg6UmxEmswiQbOVEg1CW22b1XUD/lNTliOsu0LD0wqoyleFMnmbTp697QE0pcadQiR5cVtbbAPncvpw==} @@ -7822,9 +7822,9 @@ snapshots: '@firebase/webchannel-wrapper@1.0.3': {} - '@genkit-ai/ai@1.0.0-rc.10': + '@genkit-ai/ai@1.0.0-rc.12': dependencies: - '@genkit-ai/core': 1.0.0-rc.10 + '@genkit-ai/core': 1.0.0-rc.12 '@opentelemetry/api': 1.9.0 '@types/node': 20.16.9 colorette: 2.0.20 @@ -7836,7 +7836,7 @@ snapshots: transitivePeerDependencies: - supports-color - '@genkit-ai/core@1.0.0-rc.10': + '@genkit-ai/core@1.0.0-rc.12': dependencies: '@opentelemetry/api': 1.9.0 '@opentelemetry/context-async-hooks': 1.25.1(@opentelemetry/api@1.9.0) @@ -10578,10 +10578,10 @@ snapshots: - encoding - supports-color - genkitx-openai@0.10.1(@genkit-ai/ai@1.0.0-rc.10)(@genkit-ai/core@1.0.0-rc.10): + genkitx-openai@0.10.1(@genkit-ai/ai@1.0.0-rc.12)(@genkit-ai/core@1.0.0-rc.12): dependencies: - '@genkit-ai/ai': 1.0.0-rc.10 - '@genkit-ai/core': 1.0.0-rc.10 + '@genkit-ai/ai': 1.0.0-rc.12 + '@genkit-ai/core': 1.0.0-rc.12 openai: 4.53.0(encoding@0.1.13) zod: 3.24.1 transitivePeerDependencies: