Skip to content

Commit b52ec84

Browse files
authored
feat(js/ai): Adds "restart" for tool interrupts with flexible resuming. (#1693)
1 parent d2c0617 commit b52ec84

File tree

11 files changed

+755
-236
lines changed

11 files changed

+755
-236
lines changed

js/ai/src/generate.ts

Lines changed: 30 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ import {
2222
runWithContext,
2323
runWithStreamingCallback,
2424
sentinelNoopStreamingCallback,
25-
stripUndefinedProps,
2625
z,
2726
} from '@genkit-ai/core';
2827
import { Channel } from '@genkit-ai/core/async';
@@ -46,6 +45,7 @@ import {
4645
ModelArgument,
4746
ModelMiddleware,
4847
Part,
48+
ToolRequestPart,
4949
ToolResponsePart,
5050
resolveModel,
5151
} from './model.js';
@@ -75,7 +75,16 @@ export interface ResumeOptions {
7575
* Tools have a `.reply` helper method to construct a reply ToolResponse and validate
7676
* the data against its schema. Call `myTool.reply(interruptToolRequest, yourReplyData)`.
7777
*/
78-
reply: ToolResponsePart | ToolResponsePart[];
78+
reply?: ToolResponsePart | ToolResponsePart[];
79+
/**
80+
* restart will run a tool again with additionally supplied metadata passed through as
81+
* a `resumed` option in the second argument. This allows for scenarios like conditionally
82+
* requesting confirmation of an LLM's tool request.
83+
*
84+
* Tools have a `.restart` helper method to construct a restart ToolRequest. Call
85+
* `myTool.restart(interruptToolRequest, resumeMetadata)`.
86+
*/
87+
restart?: ToolRequestPart | ToolRequestPart[];
7988
/** Additional metadata to annotate the created tool message with in the "resume" key. */
8089
metadata?: Record<string, any>;
8190
}
@@ -141,53 +150,6 @@ export interface GenerateOptions<
141150
context?: ActionContext;
142151
}
143152

144-
/** Amends message history to handle `resume` arguments. Returns the amended history. */
145-
async function applyResumeOption(
146-
options: GenerateOptions,
147-
messages: MessageData[]
148-
): Promise<MessageData[]> {
149-
if (!options.resume) return messages;
150-
if (
151-
messages.at(-1)?.role !== 'model' ||
152-
!messages
153-
.at(-1)
154-
?.content.find((p) => p.toolRequest && p.metadata?.interrupt)
155-
) {
156-
throw new GenkitError({
157-
status: 'FAILED_PRECONDITION',
158-
message: `Cannot 'resume' generation unless the previous message is a model message with at least one interrupt.`,
159-
});
160-
}
161-
const lastModelMessage = messages.at(-1)!;
162-
const toolRequests = lastModelMessage.content.filter((p) => !!p.toolRequest);
163-
164-
const pendingResponses: ToolResponsePart[] = toolRequests
165-
.filter((t) => !!t.metadata?.pendingOutput)
166-
.map((t) =>
167-
stripUndefinedProps({
168-
toolResponse: {
169-
name: t.toolRequest!.name,
170-
ref: t.toolRequest!.ref,
171-
output: t.metadata!.pendingOutput,
172-
},
173-
metadata: { source: 'pending' },
174-
})
175-
) as ToolResponsePart[];
176-
177-
const reply = Array.isArray(options.resume.reply)
178-
? options.resume.reply
179-
: [options.resume.reply];
180-
181-
const message: MessageData = {
182-
role: 'tool',
183-
content: [...pendingResponses, ...reply],
184-
metadata: {
185-
resume: options.resume.metadata || true,
186-
},
187-
};
188-
return [...messages, message];
189-
}
190-
191153
export async function toGenerateRequest(
192154
registry: Registry,
193155
options: GenerateOptions
@@ -202,8 +164,6 @@ export async function toGenerateRequest(
202164
if (options.messages) {
203165
messages.push(...options.messages.map((m) => Message.parseData(m)));
204166
}
205-
// resuming from interrupts occurs after message history but before user prompt
206-
messages = await applyResumeOption(options, messages);
207167
if (options.prompt) {
208168
messages.push({
209169
role: 'user',
@@ -216,6 +176,19 @@ export async function toGenerateRequest(
216176
message: 'at least one message is required in generate request',
217177
});
218178
}
179+
if (
180+
options.resume &&
181+
!(
182+
messages.at(-1)?.role === 'model' &&
183+
messages.at(-1)?.content.find((p) => !!p.toolRequest)
184+
)
185+
) {
186+
throw new GenkitError({
187+
status: 'FAILED_PRECONDITION',
188+
message: `Last message must be a 'model' role with at least one tool request to 'resume' generation.`,
189+
detail: messages.at(-1),
190+
});
191+
}
219192
let tools: Action<any, any>[] | undefined;
220193
if (options.tools) {
221194
tools = await resolveTools(registry, options.tools);
@@ -386,6 +359,12 @@ export async function generate<
386359
format: resolvedOptions.output.format,
387360
jsonSchema: resolvedSchema,
388361
},
362+
// coerce reply and restart into arrays for the action schema
363+
resume: resolvedOptions.resume && {
364+
reply: [resolvedOptions.resume.reply || []].flat(),
365+
restart: [resolvedOptions.resume.restart || []].flat(),
366+
metadata: resolvedOptions.resume.metadata,
367+
},
389368
returnToolRequests: resolvedOptions.returnToolRequests,
390369
maxTurns: resolvedOptions.maxTurns,
391370
};

js/ai/src/generate/action.ts

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import {
1818
Action,
1919
defineAction,
20+
GenkitError,
2021
getStreamingCallback,
2122
runWithStreamingCallback,
2223
stripUndefinedProps,
@@ -25,7 +26,7 @@ import {
2526
import { logger } from '@genkit-ai/core/logging';
2627
import { Registry } from '@genkit-ai/core/registry';
2728
import { toJsonSchema } from '@genkit-ai/core/schema';
28-
import { SPAN_TYPE_ATTR, runInNewSpan } from '@genkit-ai/core/tracing';
29+
import { runInNewSpan, SPAN_TYPE_ATTR } from '@genkit-ai/core/tracing';
2930
import {
3031
injectInstructions,
3132
resolveFormat,
@@ -52,12 +53,13 @@ import {
5253
ModelMiddleware,
5354
ModelRequest,
5455
Part,
55-
Role,
5656
resolveModel,
57+
Role,
5758
} from '../model.js';
58-
import { ToolAction, resolveTools, toToolDefinition } from '../tool.js';
59+
import { resolveTools, ToolAction, toToolDefinition } from '../tool.js';
5960
import {
6061
assertValidToolNames,
62+
resolveResumeOption,
6163
resolveToolRequests,
6264
} from './resolve-tool-requests.js';
6365

@@ -225,6 +227,23 @@ async function generate(
225227
// check to make sure we don't have overlapping tool names *before* generation
226228
await assertValidToolNames(tools);
227229

230+
const { revisedRequest, interruptedResponse } = await resolveResumeOption(
231+
registry,
232+
rawRequest
233+
);
234+
// NOTE: in the future we should make it possible to interrupt a restart, but
235+
// at the moment it's too complicated because it's not clear how to return a
236+
// response that amends history but doesn't generate a new message, so we throw
237+
if (interruptedResponse) {
238+
throw new GenkitError({
239+
status: 'FAILED_PRECONDITION',
240+
message:
241+
'One or more tools triggered an interrupt during a restarted execution.',
242+
detail: { message: interruptedResponse.message },
243+
});
244+
}
245+
rawRequest = revisedRequest!;
246+
228247
const request = await actionToGenerateRequest(
229248
rawRequest,
230249
tools,

0 commit comments

Comments
 (0)