Skip to content

Commit d918681

Browse files
authored
feat(js/ai): stream tools responses (#1614)
1 parent 7d22c8b commit d918681

File tree

14 files changed

+221
-56
lines changed

14 files changed

+221
-56
lines changed

genkit-tools/common/src/types/model.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,9 @@ export type GenerateResponseData = z.infer<typeof GenerateResponseSchema>;
245245

246246
/** ModelResponseChunkSchema represents a chunk of content to stream to the client. */
247247
export const ModelResponseChunkSchema = z.object({
248+
role: RoleSchema.optional(),
249+
/** index of the message this chunk belongs to. */
250+
index: z.number().optional(),
248251
/** The chunk of content to stream right now. */
249252
content: z.array(PartSchema),
250253
/** Model-specific extra information attached to this chunk. */
@@ -254,10 +257,7 @@ export const ModelResponseChunkSchema = z.object({
254257
});
255258
export type ModelResponseChunkData = z.infer<typeof ModelResponseChunkSchema>;
256259

257-
export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({
258-
/** @deprecated The index of the candidate this chunk belongs to. Always 0. */
259-
index: z.number(),
260-
});
260+
export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({});
261261
export type GenerateResponseChunkData = z.infer<
262262
typeof GenerateResponseChunkSchema
263263
>;

genkit-tools/genkit-schema.json

+13-5
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,12 @@
452452
"GenerateResponseChunk": {
453453
"type": "object",
454454
"properties": {
455+
"role": {
456+
"$ref": "#/$defs/Role"
457+
},
458+
"index": {
459+
"type": "number"
460+
},
455461
"content": {
456462
"type": "array",
457463
"items": {
@@ -461,14 +467,10 @@
461467
"custom": {},
462468
"aggregated": {
463469
"type": "boolean"
464-
},
465-
"index": {
466-
"type": "number"
467470
}
468471
},
469472
"required": [
470-
"content",
471-
"index"
473+
"content"
472474
],
473475
"additionalProperties": false
474476
},
@@ -714,6 +716,12 @@
714716
"ModelResponseChunk": {
715717
"type": "object",
716718
"properties": {
719+
"role": {
720+
"$ref": "#/$defs/GenerateResponseChunk/properties/role"
721+
},
722+
"index": {
723+
"$ref": "#/$defs/GenerateResponseChunk/properties/index"
724+
},
717725
"content": {
718726
"$ref": "#/$defs/GenerateResponseChunk/properties/content"
719727
},

js/ai/src/generate/action.ts

+40-13
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,11 @@ export async function generateHelper(
8383
registry: Registry,
8484
input: z.infer<typeof GenerateUtilParamSchema>,
8585
middleware?: ModelMiddleware[],
86-
currentTurns?: number
86+
currentTurns?: number,
87+
messageIndex?: number
8788
): Promise<GenerateResponseData> {
8889
currentTurns = currentTurns ?? 0;
90+
messageIndex = messageIndex ?? 0;
8991
// do tracing
9092
return await runInNewSpan(
9193
registry,
@@ -100,7 +102,13 @@ export async function generateHelper(
100102
async (metadata) => {
101103
metadata.name = 'generate';
102104
metadata.input = input;
103-
const output = await generate(registry, input, middleware, currentTurns!);
105+
const output = await generate(
106+
registry,
107+
input,
108+
middleware,
109+
currentTurns!,
110+
messageIndex!
111+
);
104112
metadata.output = JSON.stringify(output);
105113
return output;
106114
}
@@ -111,7 +119,8 @@ async function generate(
111119
registry: Registry,
112120
rawRequest: z.infer<typeof GenerateUtilParamSchema>,
113121
middleware: ModelMiddleware[] | undefined,
114-
currentTurn: number
122+
currentTurn: number,
123+
messageIndex: number
115124
): Promise<GenerateResponseData> {
116125
const { modelAction: model } = await resolveModel(registry, rawRequest.model);
117126
if (model.__action.metadata?.model.stage === 'deprecated') {
@@ -152,15 +161,17 @@ async function generate(
152161
streamingCallback
153162
? (chunk: GenerateResponseChunkData) => {
154163
// Store accumulated chunk data
155-
streamingCallback(
156-
new GenerateResponseChunk(chunk, {
157-
index: 0,
158-
role: 'model',
159-
previousChunks: accumulatedChunks,
160-
parser: resolvedFormat?.handler(request.output?.schema)
161-
.parseChunk,
162-
})
163-
);
164+
if (streamingCallback) {
165+
streamingCallback!(
166+
new GenerateResponseChunk(chunk, {
167+
index: messageIndex,
168+
role: 'model',
169+
previousChunks: accumulatedChunks,
170+
parser: resolvedFormat?.handler(request.output?.schema)
171+
.parseChunk,
172+
})
173+
);
174+
}
164175
accumulatedChunks.push(chunk);
165176
}
166177
: undefined,
@@ -246,6 +257,7 @@ async function generate(
246257
});
247258
}
248259
}
260+
messageIndex++;
249261
const nextRequest = {
250262
...rawRequest,
251263
messages: [
@@ -257,11 +269,26 @@ async function generate(
257269
] as MessageData[],
258270
tools: newTools,
259271
};
272+
// stream out the tool responses
273+
streamingCallback?.(
274+
new GenerateResponseChunk(
275+
{
276+
content: toolResponses,
277+
},
278+
{
279+
index: messageIndex,
280+
role: 'model',
281+
previousChunks: accumulatedChunks,
282+
parser: resolvedFormat?.handler(request.output?.schema).parseChunk,
283+
}
284+
)
285+
);
260286
return await generateHelper(
261287
registry,
262288
nextRequest,
263289
middleware,
264-
currentTurn + 1
290+
currentTurn + 1,
291+
messageIndex + 1
265292
);
266293
}
267294

js/ai/src/generate/chunk.ts

+18-10
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ export class GenerateResponseChunk<T = unknown>
3131
implements GenerateResponseChunkData
3232
{
3333
/** The index of the message this chunk corresponds to, starting with `0` for the first model response of the generation. */
34-
index?: number;
34+
index: number;
3535
/** The role of the message this chunk corresponds to. Will always be `model` or `tool`. */
36-
role?: Role;
36+
role: Role;
3737
/** The content generated in this chunk. */
3838
content: Part[];
3939
/** Custom model-specific data for this chunk. */
@@ -45,21 +45,21 @@ export class GenerateResponseChunk<T = unknown>
4545

4646
constructor(
4747
data: GenerateResponseChunkData,
48-
options?: {
48+
options: {
4949
previousChunks?: GenerateResponseChunkData[];
50-
role?: Role;
51-
index?: number;
50+
role: Role;
51+
index: number;
5252
parser?: ChunkParser<T>;
5353
}
5454
) {
5555
this.content = data.content || [];
5656
this.custom = data.custom;
57-
this.previousChunks = options?.previousChunks
57+
this.previousChunks = options.previousChunks
5858
? [...options.previousChunks]
5959
: undefined;
60-
this.index = options?.index;
61-
this.role = options?.role;
62-
this.parser = options?.parser;
60+
this.index = options.index;
61+
this.role = options.role;
62+
this.parser = options.parser;
6363
}
6464

6565
/**
@@ -130,6 +130,14 @@ export class GenerateResponseChunk<T = unknown>
130130
}
131131

132132
toJSON(): GenerateResponseChunkData {
133-
return { content: this.content, custom: this.custom };
133+
const data = {
134+
role: this.role,
135+
index: this.index,
136+
content: this.content,
137+
} as GenerateResponseChunkData;
138+
if (this.custom) {
139+
data.custom = this.custom;
140+
}
141+
return data;
134142
}
135143
}

js/ai/src/model.ts

+4-4
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,9 @@ export type GenerateResponseData = z.infer<typeof GenerateResponseSchema>;
400400

401401
/** ModelResponseChunkSchema represents a chunk of content to stream to the client. */
402402
export const ModelResponseChunkSchema = z.object({
403+
role: RoleSchema.optional(),
404+
/** index of the message this chunk belongs to. */
405+
index: z.number().optional(),
403406
/** The chunk of content to stream right now. */
404407
content: z.array(PartSchema),
405408
/** Model-specific extra information attached to this chunk. */
@@ -409,10 +412,7 @@ export const ModelResponseChunkSchema = z.object({
409412
});
410413
export type ModelResponseChunkData = z.infer<typeof ModelResponseChunkSchema>;
411414

412-
export const GenerateResponseChunkSchema = ModelResponseChunkSchema.extend({
413-
/** @deprecated The index of the candidate this chunk belongs to. Always 0. */
414-
index: z.number().optional(),
415-
});
415+
export const GenerateResponseChunkSchema = ModelResponseChunkSchema;
416416
export type GenerateResponseChunkData = z.infer<
417417
typeof GenerateResponseChunkSchema
418418
>;

js/ai/tests/formats/array_test.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,17 @@ describe('arrayFormat', () => {
7171

7272
for (const chunk of st.chunks) {
7373
const newChunk: GenerateResponseChunkData = {
74+
index: 0,
75+
role: 'model',
7476
content: [{ text: chunk.text }],
7577
};
7678

7779
const result = parser.parseChunk!(
78-
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
80+
new GenerateResponseChunk(newChunk, {
81+
index: 0,
82+
role: 'model',
83+
previousChunks: chunks,
84+
})
7985
);
8086
chunks.push(newChunk);
8187

js/ai/tests/formats/json_test.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,17 @@ describe('jsonFormat', () => {
6767

6868
for (const chunk of st.chunks) {
6969
const newChunk: GenerateResponseChunkData = {
70+
index: 0,
71+
role: 'model',
7072
content: [{ text: chunk.text }],
7173
};
7274

7375
const result = parser.parseChunk!(
74-
new GenerateResponseChunk(newChunk, { previousChunks: [...chunks] })
76+
new GenerateResponseChunk(newChunk, {
77+
index: 0,
78+
role: 'model',
79+
previousChunks: [...chunks],
80+
})
7581
);
7682
chunks.push(newChunk);
7783

js/ai/tests/formats/jsonl_test.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,17 @@ describe('jsonlFormat', () => {
8080

8181
for (const chunk of st.chunks) {
8282
const newChunk: GenerateResponseChunkData = {
83+
index: 0,
84+
role: 'model',
8385
content: [{ text: chunk.text }],
8486
};
8587

8688
const result = parser.parseChunk!(
87-
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
89+
new GenerateResponseChunk(newChunk, {
90+
index: 0,
91+
role: 'model',
92+
previousChunks: chunks,
93+
})
8894
);
8995
chunks.push(newChunk);
9096

js/ai/tests/formats/text_test.ts

+7-1
Original file line numberDiff line numberDiff line change
@@ -54,11 +54,17 @@ describe('textFormat', () => {
5454

5555
for (const chunk of st.chunks) {
5656
const newChunk: GenerateResponseChunkData = {
57+
index: 0,
58+
role: 'model',
5759
content: [{ text: chunk.text }],
5860
};
5961

6062
const result = parser.parseChunk!(
61-
new GenerateResponseChunk(newChunk, { previousChunks: chunks })
63+
new GenerateResponseChunk(newChunk, {
64+
index: 0,
65+
role: 'model',
66+
previousChunks: chunks,
67+
})
6268
);
6369
chunks.push(newChunk);
6470

js/ai/tests/generate/chunk_test.ts

+3-3
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@ import { GenerateResponseChunk } from '../../src/generate.js';
2121
describe('GenerateResponseChunk', () => {
2222
describe('text accumulation', () => {
2323
const testChunk = new GenerateResponseChunk(
24-
{ content: [{ text: 'new' }] },
24+
{ index: 0, role: 'model', content: [{ text: 'new' }] },
2525
{
2626
previousChunks: [
27-
{ content: [{ text: 'old1' }] },
28-
{ content: [{ text: 'old2' }] },
27+
{ index: 0, role: 'model', content: [{ text: 'old1' }] },
28+
{ index: 0, role: 'model', content: [{ text: 'old2' }] },
2929
],
3030
}
3131
);

js/ai/tests/generate/generate_test.ts

+15-4
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ describe('generate', () => {
369369
});
370370

371371
describe('generateStream', () => {
372-
it('should pass a smoke test', async () => {
372+
it('should stream out chunks', async () => {
373373
let registry = new Registry();
374374

375375
defineModel(
@@ -390,11 +390,22 @@ describe('generate', () => {
390390
prompt: 'Testing streaming',
391391
});
392392

393-
let streamed: string[] = [];
393+
let streamed: any[] = [];
394394
for await (const chunk of stream) {
395-
streamed.push(chunk.text);
395+
streamed.push(chunk.toJSON());
396396
}
397-
assert.deepEqual(streamed, ['hello, ', 'world!']);
397+
assert.deepStrictEqual(streamed, [
398+
{
399+
index: 0,
400+
role: 'model',
401+
content: [{ text: 'hello, ' }],
402+
},
403+
{
404+
index: 0,
405+
role: 'model',
406+
content: [{ text: 'world!' }],
407+
},
408+
]);
398409
assert.deepEqual(
399410
(await response).messages.map((m) => m.content[0].text),
400411
['Testing streaming', 'Testing streaming']

0 commit comments

Comments (0)