Skip to content

Commit

Permalink
feat(js/ai): Adds "restart" for tool interrupts with flexible resuming. (#1693)
Browse files Browse the repository at this point in the history
  • Loading branch information
mbleigh authored Jan 31, 2025
1 parent d2c0617 commit b52ec84
Show file tree
Hide file tree
Showing 11 changed files with 755 additions and 236 deletions.
81 changes: 30 additions & 51 deletions js/ai/src/generate.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import {
runWithContext,
runWithStreamingCallback,
sentinelNoopStreamingCallback,
stripUndefinedProps,
z,
} from '@genkit-ai/core';
import { Channel } from '@genkit-ai/core/async';
Expand All @@ -46,6 +45,7 @@ import {
ModelArgument,
ModelMiddleware,
Part,
ToolRequestPart,
ToolResponsePart,
resolveModel,
} from './model.js';
Expand Down Expand Up @@ -75,7 +75,16 @@ export interface ResumeOptions {
* Tools have a `.reply` helper method to construct a reply ToolResponse and validate
* the data against its schema. Call `myTool.reply(interruptToolRequest, yourReplyData)`.
*/
reply: ToolResponsePart | ToolResponsePart[];
reply?: ToolResponsePart | ToolResponsePart[];
/**
* restart will run a tool again with additionally supplied metadata passed through as
* a `resumed` option in the second argument. This allows for scenarios like conditionally
* requesting confirmation of an LLM's tool request.
*
* Tools have a `.restart` helper method to construct a restart ToolRequest. Call
* `myTool.restart(interruptToolRequest, resumeMetadata)`.
*/
restart?: ToolRequestPart | ToolRequestPart[];
/** Additional metadata to annotate the created tool message with in the "resume" key. */
metadata?: Record<string, any>;
}
Expand Down Expand Up @@ -141,53 +150,6 @@ export interface GenerateOptions<
context?: ActionContext;
}

/**
 * Amends message history to handle the `resume` option.
 *
 * When `options.resume` is set, a new `tool` role message is appended after
 * the history. Its content is, in order:
 *
 * 1. one tool response for every tool request in the last model message that
 *    carries a truthy `metadata.pendingOutput`, tagged with
 *    `metadata.source === 'pending'`;
 * 2. the replies supplied in `options.resume.reply` (a single part or an
 *    array; omitted replies contribute nothing).
 *
 * The message's `metadata.resume` is `options.resume.metadata` when that is
 * truthy, otherwise `true`.
 *
 * @param options Generate options; only `options.resume` is read.
 * @param messages Message history so far. The array is not mutated.
 * @returns `messages` itself when `options.resume` is not set, otherwise a new
 *   array consisting of `messages` followed by the generated tool message.
 * @throws GenkitError with status `FAILED_PRECONDITION` when resuming and the
 *   last message is not a `model` message containing at least one tool
 *   request marked with `metadata.interrupt`.
 */
async function applyResumeOption(
  options: GenerateOptions,
  messages: MessageData[]
): Promise<MessageData[]> {
  const resume = options.resume;
  if (!resume) return messages;

  // Look the last message up once instead of re-evaluating `messages.at(-1)`.
  const lastMessage = messages.at(-1);
  const endsWithInterrupt =
    lastMessage?.role === 'model' &&
    lastMessage.content.some((p) => p.toolRequest && p.metadata?.interrupt);
  if (!lastMessage || !endsWithInterrupt) {
    throw new GenkitError({
      status: 'FAILED_PRECONDITION',
      message: `Cannot 'resume' generation unless the previous message is a model message with at least one interrupt.`,
    });
  }

  // Tool requests that already produced an output before the interrupt are
  // answered automatically from their stashed `pendingOutput`.
  const pendingResponses: ToolResponsePart[] = lastMessage.content
    .filter((p) => !!p.toolRequest && !!p.metadata?.pendingOutput)
    .map((p) =>
      stripUndefinedProps({
        toolResponse: {
          name: p.toolRequest!.name,
          ref: p.toolRequest!.ref,
          output: p.metadata!.pendingOutput,
        },
        metadata: { source: 'pending' },
      })
    ) as ToolResponsePart[];

  // `reply` may be a single part, an array, or omitted entirely. Coerce to a
  // flat array so a missing reply never injects `undefined` into the content.
  const replies = [resume.reply ?? []].flat();

  const toolMessage: MessageData = {
    role: 'tool',
    content: [...pendingResponses, ...replies],
    metadata: {
      resume: resume.metadata || true,
    },
  };
  return [...messages, toolMessage];
}

export async function toGenerateRequest(
registry: Registry,
options: GenerateOptions
Expand All @@ -202,8 +164,6 @@ export async function toGenerateRequest(
if (options.messages) {
messages.push(...options.messages.map((m) => Message.parseData(m)));
}
// resuming from interrupts occurs after message history but before user prompt
messages = await applyResumeOption(options, messages);
if (options.prompt) {
messages.push({
role: 'user',
Expand All @@ -216,6 +176,19 @@ export async function toGenerateRequest(
message: 'at least one message is required in generate request',
});
}
if (
options.resume &&
!(
messages.at(-1)?.role === 'model' &&
messages.at(-1)?.content.find((p) => !!p.toolRequest)
)
) {
throw new GenkitError({
status: 'FAILED_PRECONDITION',
message: `Last message must be a 'model' role with at least one tool request to 'resume' generation.`,
detail: messages.at(-1),
});
}
let tools: Action<any, any>[] | undefined;
if (options.tools) {
tools = await resolveTools(registry, options.tools);
Expand Down Expand Up @@ -386,6 +359,12 @@ export async function generate<
format: resolvedOptions.output.format,
jsonSchema: resolvedSchema,
},
// coerce reply and restart into arrays for the action schema
resume: resolvedOptions.resume && {
reply: [resolvedOptions.resume.reply || []].flat(),
restart: [resolvedOptions.resume.restart || []].flat(),
metadata: resolvedOptions.resume.metadata,
},
returnToolRequests: resolvedOptions.returnToolRequests,
maxTurns: resolvedOptions.maxTurns,
};
Expand Down
25 changes: 22 additions & 3 deletions js/ai/src/generate/action.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import {
Action,
defineAction,
GenkitError,
getStreamingCallback,
runWithStreamingCallback,
stripUndefinedProps,
Expand All @@ -25,7 +26,7 @@ import {
import { logger } from '@genkit-ai/core/logging';
import { Registry } from '@genkit-ai/core/registry';
import { toJsonSchema } from '@genkit-ai/core/schema';
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
import {
injectInstructions,
resolveFormat,
Expand All @@ -52,12 +53,13 @@ import {
ModelMiddleware,
ModelRequest,
Part,
Role,
resolveModel,
Role,
} from '../model.js';
import { ToolAction, resolveTools, toToolDefinition } from '../tool.js';
import { resolveTools, ToolAction, toToolDefinition } from '../tool.js';
import {
assertValidToolNames,
resolveResumeOption,
resolveToolRequests,
} from './resolve-tool-requests.js';

Expand Down Expand Up @@ -225,6 +227,23 @@ async function generate(
// check to make sure we don't have overlapping tool names *before* generation
await assertValidToolNames(tools);

const { revisedRequest, interruptedResponse } = await resolveResumeOption(
registry,
rawRequest
);
// NOTE: in the future we should make it possible to interrupt a restart, but
// at the moment it's too complicated because it's not clear how to return a
// response that amends history but doesn't generate a new message, so we throw
if (interruptedResponse) {
throw new GenkitError({
status: 'FAILED_PRECONDITION',
message:
'One or more tools triggered an interrupt during a restarted execution.',
detail: { message: interruptedResponse.message },
});
}
rawRequest = revisedRequest!;

const request = await actionToGenerateRequest(
rawRequest,
tools,
Expand Down
Loading

0 comments on commit b52ec84

Please sign in to comment.