Skip to content

Commit bae3a2c

Browse files
authored
refactor(js/ai/chat): make sendStream sync (#1732)
1 parent d53b930 commit bae3a2c

File tree

3 files changed

+46
-76
lines changed

3 files changed

+46
-76
lines changed

js/ai/src/chat.ts

Lines changed: 42 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
*/
1616

1717
import { StreamingCallback, z } from '@genkit-ai/core';
18+
import { Channel } from '@genkit-ai/core/async';
1819
import {
1920
ATTR_PREFIX,
2021
SPAN_TYPE_ATTR,
@@ -30,7 +31,6 @@ import {
3031
MessageData,
3132
Part,
3233
generate,
33-
generateStream,
3434
} from './index.js';
3535
import {
3636
BaseGenerateOptions,
@@ -154,23 +154,12 @@ export class Chat {
154154
},
155155
},
156156
async (metadata) => {
157-
let resolvedOptions: ChatGenerateOptions<O, CustomOptions>;
157+
let resolvedOptions = resolveSendOptions(options);
158158
let streamingCallback:
159159
| StreamingCallback<GenerateResponseChunk>
160160
| undefined = undefined;
161161

162-
// string
163-
if (typeof options === 'string') {
164-
resolvedOptions = {
165-
prompt: options,
166-
} as ChatGenerateOptions<O, CustomOptions>;
167-
} else if (Array.isArray(options)) {
168-
// Part[]
169-
resolvedOptions = {
170-
prompt: options,
171-
} as ChatGenerateOptions<O, CustomOptions>;
172-
} else {
173-
resolvedOptions = options as ChatGenerateOptions<O, CustomOptions>;
162+
if (resolvedOptions.onChunk || resolvedOptions.streamingCallback) {
174163
streamingCallback =
175164
resolvedOptions.onChunk ?? resolvedOptions.streamingCallback;
176165
}
@@ -204,66 +193,23 @@ export class Chat {
204193
CustomOptions extends z.ZodTypeAny = typeof GenerationCommonConfigSchema,
205194
>(
206195
options: string | Part[] | GenerateStreamOptions<O, CustomOptions>
207-
): Promise<GenerateStreamResponse<z.infer<O>>> {
208-
return runWithSession(this.session.registry, this.session, () =>
209-
runInNewSpan(
210-
this.session.registry,
211-
{
212-
metadata: { name: 'send' },
213-
labels: {
214-
[SPAN_TYPE_ATTR]: 'helper',
215-
[SESSION_ID_ATTR]: this.session.id,
216-
[THREAD_NAME_ATTR]: this.threadName,
217-
},
218-
},
219-
async (metadata) => {
220-
let resolvedOptions;
221-
222-
// string
223-
if (typeof options === 'string') {
224-
resolvedOptions = {
225-
prompt: options,
226-
} as GenerateStreamOptions<O, CustomOptions>;
227-
} else if (Array.isArray(options)) {
228-
// Part[]
229-
resolvedOptions = {
230-
prompt: options,
231-
} as GenerateStreamOptions<O, CustomOptions>;
232-
} else {
233-
resolvedOptions = options as GenerateStreamOptions<
234-
O,
235-
CustomOptions
236-
>;
237-
}
238-
metadata.input = resolvedOptions;
239-
240-
const { response, stream } = await generateStream(
241-
this.session.registry,
242-
{
243-
...(await this.requestBase),
244-
messages: this.messages,
245-
...resolvedOptions,
246-
}
247-
);
196+
): GenerateStreamResponse<z.infer<O>> {
197+
let channel = new Channel<GenerateResponseChunk>();
198+
let resolvedOptions = resolveSendOptions(options);
248199

249-
return {
250-
response: response.finally(async () => {
251-
const resolvedResponse = await response;
252-
this.requestBase = Promise.resolve({
253-
...(await this.requestBase),
254-
// these things may get changed by tools calling within generate.
255-
tools: resolvedResponse?.request?.tools,
256-
toolChoice: resolvedResponse?.request?.toolChoice,
257-
config: resolvedResponse?.request?.config,
258-
});
259-
this.updateMessages(resolvedResponse.messages);
260-
metadata.output = JSON.stringify(resolvedResponse);
261-
}),
262-
stream,
263-
};
264-
}
265-
)
200+
const sent = this.send({
201+
...resolvedOptions,
202+
onChunk: (chunk) => channel.send(chunk),
203+
});
204+
sent.then(
205+
() => channel.close(),
206+
(err) => channel.error(err)
266207
);
208+
209+
return {
210+
response: sent,
211+
stream: channel,
212+
};
267213
}
268214

269215
get messages(): MessageData[] {
@@ -287,3 +233,27 @@ function getPreamble(msgs?: MessageData[]) {
287233
function stripPreamble(msgs?: MessageData[]) {
288234
return msgs?.filter((m) => !m.metadata?.preamble);
289235
}
236+
237+
function resolveSendOptions<
238+
O extends z.ZodTypeAny,
239+
CustomOptions extends z.ZodTypeAny,
240+
>(
241+
options: string | Part[] | ChatGenerateOptions<O, CustomOptions>
242+
): ChatGenerateOptions<O, CustomOptions> {
243+
let resolvedOptions: ChatGenerateOptions<O, CustomOptions>;
244+
245+
// string
246+
if (typeof options === 'string') {
247+
resolvedOptions = {
248+
prompt: options,
249+
} as ChatGenerateOptions<O, CustomOptions>;
250+
} else if (Array.isArray(options)) {
251+
// Part[]
252+
resolvedOptions = {
253+
prompt: options,
254+
} as ChatGenerateOptions<O, CustomOptions>;
255+
} else {
256+
resolvedOptions = options as ChatGenerateOptions<O, CustomOptions>;
257+
}
258+
return resolvedOptions;
259+
}

js/genkit/tests/chat_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ describe('chat', () => {
6666

6767
it('maintains history in the session with streaming', async () => {
6868
const chat = ai.chat();
69-
let { response, stream } = await chat.sendStream('hi');
69+
let { response, stream } = chat.sendStream('hi');
7070

7171
let chunks: string[] = [];
7272
for await (const chunk of stream) {
@@ -75,7 +75,7 @@ describe('chat', () => {
7575
assert.strictEqual((await response).text, 'Echo: hi; config: {}');
7676
assert.deepStrictEqual(chunks, ['3', '2', '1']);
7777

78-
({ response, stream } = await chat.sendStream('bye'));
78+
({ response, stream } = chat.sendStream('bye'));
7979

8080
chunks = [];
8181
for await (const chunk of stream) {

js/genkit/tests/session_test.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,7 @@ describe('session', () => {
170170
const session = ai.createSession();
171171
const chat = session.chat();
172172

173-
let { response, stream } = await chat.sendStream('hi');
173+
let { response, stream } = chat.sendStream('hi');
174174

175175
let chunks: string[] = [];
176176
for await (const chunk of stream) {
@@ -179,7 +179,7 @@ describe('session', () => {
179179
assert.strictEqual((await response).text, 'Echo: hi; config: {}');
180180
assert.deepStrictEqual(chunks, ['3', '2', '1']);
181181

182-
({ response, stream } = await chat.sendStream('bye'));
182+
({ response, stream } = chat.sendStream('bye'));
183183

184184
chunks = [];
185185
for await (const chunk of stream) {

0 commit comments

Comments
 (0)