Skip to content

Commit 3efcc3e

Browse files
committed
Handle processing chunked event streams
1 parent 913f524 commit 3efcc3e

File tree

2 files changed

+310
-6
lines changed

2 files changed

+310
-6
lines changed

index.test.ts

+286-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ import Replicate, {
88
} from "replicate";
99
import nock from "nock";
1010
import fetch from "cross-fetch";
11+
import { Stream } from "./lib/stream";
12+
import { PassThrough } from "node:stream";
1113

1214
let client: Replicate;
1315
const BASE_URL = "https://api.replicate.com/v1";
@@ -251,7 +253,7 @@ describe("Replicate client", () => {
251253
let actual: Record<string, any> | undefined;
252254
nock(BASE_URL)
253255
.post("/predictions")
254-
.reply(201, (uri: string, body: Record<string, any>) => {
256+
.reply(201, (_uri: string, body: Record<string, any>) => {
255257
actual = body;
256258
return body;
257259
});
@@ -1010,8 +1012,6 @@ describe("Replicate client", () => {
10101012
});
10111013

10121014
test("Calls the correct API routes for a model", async () => {
1013-
const firstPollingRequest = true;
1014-
10151015
nock(BASE_URL)
10161016
.post("/models/replicate/hello-world/predictions")
10171017
.reply(201, {
@@ -1179,12 +1179,294 @@ describe("Replicate client", () => {
11791179
// This is a test secret and should not be used in production
11801180
const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw";
11811181

1182-
const isValid = await validateWebhook(request, secret);
1182+
const isValid = validateWebhook(request, secret);
11831183
expect(isValid).toBe(true);
11841184
});
11851185

11861186
// Add more tests for error handling, edge cases, etc.
11871187
});
11881188

11891189
// Continue with tests for other methods
1190+
1191+
describe("Stream", () => {
1192+
function createStream(body: string | NodeJS.ReadableStream) {
1193+
const streamEndpoint = "https://stream.replicate.com";
1194+
nock(streamEndpoint)
1195+
.get("/fake_stream")
1196+
.matchHeader("Accept", "text/event-stream")
1197+
.reply(200, body);
1198+
1199+
return new Stream({ url: `${streamEndpoint}/fake_stream`, fetch });
1200+
}
1201+
1202+
test("consumes a server sent event stream", async () => {
1203+
const stream = createStream(
1204+
`
1205+
event: output
1206+
id: EVENT_1
1207+
data: hello world
1208+
1209+
event: done
1210+
id: EVENT_2
1211+
data: {}
1212+
`
1213+
.trim()
1214+
.replace(/^[ ]+/gm, "")
1215+
);
1216+
1217+
const iterator = stream[Symbol.asyncIterator]();
1218+
1219+
expect(await iterator.next()).toEqual({
1220+
done: false,
1221+
value: { event: "output", id: "EVENT_1", data: "hello world" },
1222+
});
1223+
expect(await iterator.next()).toEqual({
1224+
done: false,
1225+
value: { event: "done", id: "EVENT_2", data: "{}" },
1226+
});
1227+
expect(await iterator.next()).toEqual({ done: true });
1228+
expect(await iterator.next()).toEqual({ done: true });
1229+
});
1230+
1231+
test("consumes multiple events", async () => {
1232+
const stream = createStream(
1233+
`
1234+
event: output
1235+
id: EVENT_1
1236+
data: hello world
1237+
1238+
event: output
1239+
id: EVENT_2
1240+
data: hello dave
1241+
1242+
event: done
1243+
id: EVENT_3
1244+
data: {}
1245+
`
1246+
.trim()
1247+
.replace(/^[ ]+/gm, "")
1248+
);
1249+
1250+
const iterator = stream[Symbol.asyncIterator]();
1251+
1252+
expect(await iterator.next()).toEqual({
1253+
done: false,
1254+
value: { event: "output", id: "EVENT_1", data: "hello world" },
1255+
});
1256+
expect(await iterator.next()).toEqual({
1257+
done: false,
1258+
value: { event: "output", id: "EVENT_2", data: "hello dave" },
1259+
});
1260+
expect(await iterator.next()).toEqual({
1261+
done: false,
1262+
value: { event: "done", id: "EVENT_3", data: "{}" },
1263+
});
1264+
expect(await iterator.next()).toEqual({ done: true });
1265+
expect(await iterator.next()).toEqual({ done: true });
1266+
});
1267+
1268+
test("ignores unexpected characters", async () => {
1269+
const stream = createStream(
1270+
`
1271+
: hi
1272+
1273+
event: output
1274+
id: EVENT_1
1275+
data: hello world
1276+
1277+
event: done
1278+
id: EVENT_2
1279+
data: {}
1280+
`
1281+
.trim()
1282+
.replace(/^[ ]+/gm, "")
1283+
);
1284+
1285+
const iterator = stream[Symbol.asyncIterator]();
1286+
1287+
expect(await iterator.next()).toEqual({
1288+
done: false,
1289+
value: { event: "output", id: "EVENT_1", data: "hello world" },
1290+
});
1291+
expect(await iterator.next()).toEqual({
1292+
done: false,
1293+
value: { event: "done", id: "EVENT_2", data: "{}" },
1294+
});
1295+
expect(await iterator.next()).toEqual({ done: true });
1296+
expect(await iterator.next()).toEqual({ done: true });
1297+
});
1298+
1299+
test("supports multiple lines of output in a single event", async () => {
1300+
const stream = createStream(
1301+
`
1302+
: hi
1303+
1304+
event: output
1305+
id: EVENT_1
1306+
data: hello,
1307+
data: this is a new line,
1308+
data: and this is a new line too
1309+
1310+
event: done
1311+
id: EVENT_2
1312+
data: {}
1313+
`
1314+
.trim()
1315+
.replace(/^[ ]+/gm, "")
1316+
);
1317+
1318+
const iterator = stream[Symbol.asyncIterator]();
1319+
1320+
expect(await iterator.next()).toEqual({
1321+
done: false,
1322+
value: {
1323+
event: "output",
1324+
id: "EVENT_1",
1325+
data: "hello,\nthis is a new line,\nand this is a new line too",
1326+
},
1327+
});
1328+
expect(await iterator.next()).toEqual({
1329+
done: false,
1330+
value: { event: "done", id: "EVENT_2", data: "{}" },
1331+
});
1332+
expect(await iterator.next()).toEqual({ done: true });
1333+
expect(await iterator.next()).toEqual({ done: true });
1334+
});
1335+
1336+
test("supports the server writing data lines in multiple chunks", async () => {
1337+
const body = new PassThrough();
1338+
const stream = createStream(body);
1339+
1340+
// Create a stream of data chunks split on the pipe character for readability.
1341+
const data = `
1342+
event: output
1343+
id: EVENT_1
1344+
data: hello,|
1345+
data: this is a new line,|
1346+
data: and this is a new line too
1347+
1348+
event: done
1349+
id: EVENT_2
1350+
data: {}
1351+
`
1352+
.trim()
1353+
.replace(/^[ ]+/gm, "");
1354+
1355+
const chunks = data.split("|");
1356+
1357+
// Consume the iterator in parallel to writing it.
1358+
const reading = new Promise((resolve, reject) => {
1359+
(async () => {
1360+
const iterator = stream[Symbol.asyncIterator]();
1361+
expect(await iterator.next()).toEqual({
1362+
done: false,
1363+
value: {
1364+
event: "output",
1365+
id: "EVENT_1",
1366+
data: "hello,\nthis is a new line,\nand this is a new line too",
1367+
},
1368+
});
1369+
expect(await iterator.next()).toEqual({
1370+
done: false,
1371+
value: { event: "done", id: "EVENT_2", data: "{}" },
1372+
});
1373+
expect(await iterator.next()).toEqual({ done: true });
1374+
})().then(resolve, reject);
1375+
});
1376+
1377+
// Write the chunks to the stream at an interval.
1378+
const writing = new Promise((resolve, reject) => {
1379+
(async () => {
1380+
for await (const chunk of chunks) {
1381+
body.write(chunk);
1382+
await new Promise((resolve) => setTimeout(resolve, 1));
1383+
}
1384+
body.end();
1385+
resolve(null);
1386+
})().then(resolve, reject);
1387+
});
1388+
1389+
// Wait for both promises to resolve.
1390+
await Promise.all([reading, writing]);
1391+
});
1392+
1393+
test("supports the server writing data in a complete mess", async () => {
1394+
const body = new PassThrough();
1395+
const stream = createStream(body);
1396+
1397+
// Create a stream of data chunks split on the pipe character for readability.
1398+
const data = `
1399+
: hi
1400+
1401+
ev|ent: output
1402+
id: EVENT_1
1403+
data: hello,
1404+
data: this |is a new line,|
1405+
data: and this is |a new line too
1406+
1407+
event: d|one
1408+
id: EVENT|_2
1409+
data: {}
1410+
`
1411+
.trim()
1412+
.replace(/^[ ]+/gm, "");
1413+
1414+
const chunks = data.split("|");
1415+
1416+
// Consume the iterator in parallel to writing it.
1417+
const reading = new Promise((resolve, reject) => {
1418+
(async () => {
1419+
const iterator = stream[Symbol.asyncIterator]();
1420+
expect(await iterator.next()).toEqual({
1421+
done: false,
1422+
value: {
1423+
event: "output",
1424+
id: "EVENT_1",
1425+
data: "hello,\nthis is a new line,\nand this is a new line too",
1426+
},
1427+
});
1428+
expect(await iterator.next()).toEqual({
1429+
done: false,
1430+
value: { event: "done", id: "EVENT_2", data: "{}" },
1431+
});
1432+
expect(await iterator.next()).toEqual({ done: true });
1433+
})().then(resolve, reject);
1434+
});
1435+
1436+
// Write the chunks to the stream at an interval.
1437+
const writing = new Promise((resolve, reject) => {
1438+
(async () => {
1439+
for await (const chunk of chunks) {
1440+
body.write(chunk);
1441+
await new Promise((resolve) => setTimeout(resolve, 1));
1442+
}
1443+
body.end();
1444+
resolve(null);
1445+
})().then(resolve, reject);
1446+
});
1447+
1448+
// Wait for both promises to resolve.
1449+
await Promise.all([reading, writing]);
1450+
});
1451+
1452+
test("supports ending without a done", async () => {
1453+
const stream = createStream(
1454+
`
1455+
event: output
1456+
id: EVENT_1
1457+
data: hello world
1458+
1459+
`
1460+
.trim()
1461+
.replace(/^[ ]+/gm, "")
1462+
);
1463+
1464+
const iterator = stream[Symbol.asyncIterator]();
1465+
expect(await iterator.next()).toEqual({
1466+
done: false,
1467+
value: { event: "output", id: "EVENT_1", data: "hello world" },
1468+
});
1469+
expect(await iterator.next()).toEqual({ done: true });
1470+
});
1471+
});
11901472
});

lib/stream.js

+24-2
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ class Stream extends Readable {
5050
*
5151
* @param {object} config
5252
* @param {string} config.url The URL to connect to.
53-
* @param {Function} [config.fetch] The fetch implemention to use.
53+
* @param {typeof fetch} [config.fetch] The fetch implementation to use.
5454
* @param {object} [config.options] The fetch options.
5555
*/
5656
constructor({ url, fetch = globalThis.fetch, options = {} }) {
@@ -114,10 +114,21 @@ class Stream extends Readable {
114114
},
115115
});
116116

117+
if (!response.ok) {
118+
throw new Error();
119+
}
120+
121+
let partialChunk = "";
117122
for await (const chunk of response.body) {
118123
const decoder = new TextDecoder("utf-8");
119-
const text = decoder.decode(chunk);
124+
const text = partialChunk + decoder.decode(chunk);
120125
const lines = text.split("\n");
126+
127+
// We want to ensure that the last line is not a fragment
128+
// so we keep it and append it to the start of the next
129+
// chunk.
130+
partialChunk = lines.pop();
131+
121132
for (const line of lines) {
122133
const sse = this.decode(line);
123134
if (sse) {
@@ -133,6 +144,17 @@ class Stream extends Readable {
133144
}
134145
}
135146
}
147+
148+
// Process the final line and ensure we have captured the final event.
149+
this.decode(partialChunk);
150+
const sse = this.decode("");
151+
if (sse) {
152+
if (sse.event === "error") {
153+
throw new Error(sse.data);
154+
}
155+
156+
yield sse;
157+
}
136158
}
137159
}
138160

0 commit comments

Comments
 (0)