Skip to content

Commit 82d9056

Browse files
authoredMar 12, 2024··
Close the stream when receiving "done" event from the server (#219)
This fixes a regression introduced with #214 where we were not exiting correctly when getting the `"done"` event from the server. This was picked up by the introduction of the CloudFlare integration tests added in #217 which uses the streaming API. Once the fix was added it turns out that the `nock()` tests were incorrectly passing due to some internal weirdness when using `respondWith` and a `Readable` object. I wasn't able to get this working without hitting a different error: TypeError: Invalid state: Controller is already closed It looks like nock is retaining some global state somewhere in it's implementation and streams are being retained across requests. No combination of resetting mocks seemed to fix it. In the end I just mocked out the fetch function passed into the `createReadableStream` library and returned a `Response`. I think we should probably do this everywhere rather than use `nock()` as the Request/Response APIs provided by fetch are much better now than the old node http lib.
1 parent 3c62031 commit 82d9056

File tree

7 files changed

+85
-89
lines changed

7 files changed

+85
-89
lines changed
 

‎index.test.ts

+55-75
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@ import Replicate, {
77
parseProgressFromLogs,
88
} from "replicate";
99
import nock from "nock";
10+
import { Readable } from "node:stream";
1011
import { createReadableStream } from "./lib/stream";
11-
import { PassThrough } from "node:stream";
1212

1313
let client: Replicate;
1414
const BASE_URL = "https://api.replicate.com/v1";
@@ -1187,16 +1187,17 @@ describe("Replicate client", () => {
11871187
// Continue with tests for other methods
11881188

11891189
describe("createReadableStream", () => {
1190-
function createStream(body: string | NodeJS.ReadableStream, status = 200) {
1191-
const streamEndpoint = "https://stream.replicate.com";
1192-
nock(streamEndpoint)
1193-
.get("/fake_stream")
1194-
.matchHeader("Accept", "text/event-stream")
1195-
.reply(status, body);
1196-
1190+
function createStream(body: string | ReadableStream, status = 200) {
1191+
const streamEndpoint = "https://stream.replicate.com/fake_stream";
1192+
const fetch = jest.fn((url) => {
1193+
if (url !== streamEndpoint) {
1194+
throw new Error(`Unmocked call to fetch() with url: ${url}`);
1195+
}
1196+
return new Response(body, { status });
1197+
});
11971198
return createReadableStream({
1198-
url: `${streamEndpoint}/fake_stream`,
1199-
fetch: fetch,
1199+
url: streamEndpoint,
1200+
fetch: fetch as any,
12001201
});
12011202
}
12021203

@@ -1330,9 +1331,6 @@ describe("Replicate client", () => {
13301331
});
13311332

13321333
test("supports the server writing data lines in multiple chunks", async () => {
1333-
const body = new PassThrough();
1334-
const stream = createStream(body);
1335-
13361334
// Create a stream of data chunks split on the pipe character for readability.
13371335
const data = `
13381336
event: output
@@ -1348,45 +1346,47 @@ describe("Replicate client", () => {
13481346
`.replace(/^[ ]+/gm, "");
13491347

13501348
const chunks = data.split("|");
1349+
const body = new ReadableStream({
1350+
async pull(controller) {
1351+
if (chunks.length) {
1352+
await new Promise((resolve) => setTimeout(resolve, 1));
1353+
const chunk = chunks.shift();
1354+
controller.enqueue(new TextEncoder().encode(chunk));
1355+
}
1356+
},
1357+
});
1358+
1359+
const stream = createStream(body);
13511360

13521361
// Consume the iterator in parallel to writing it.
1353-
const reading = new Promise((resolve, reject) => {
1354-
(async () => {
1355-
const iterator = stream[Symbol.asyncIterator]();
1356-
expect(await iterator.next()).toEqual({
1357-
done: false,
1358-
value: {
1359-
event: "output",
1360-
id: "EVENT_1",
1361-
data: "hello,\nthis is a new line,\nand this is a new line too",
1362-
},
1363-
});
1364-
expect(await iterator.next()).toEqual({
1365-
done: false,
1366-
value: { event: "done", id: "EVENT_2", data: "{}" },
1367-
});
1368-
expect(await iterator.next()).toEqual({ done: true });
1369-
})().then(resolve, reject);
1362+
const iterator = stream[Symbol.asyncIterator]();
1363+
expect(await iterator.next()).toEqual({
1364+
done: false,
1365+
value: {
1366+
event: "output",
1367+
id: "EVENT_1",
1368+
data: "hello,\nthis is a new line,\nand this is a new line too",
1369+
},
13701370
});
1371-
1372-
// Write the chunks to the stream at an interval.
1373-
const writing = new Promise((resolve, reject) => {
1374-
(async () => {
1375-
for await (const chunk of chunks) {
1376-
body.write(chunk);
1377-
await new Promise((resolve) => setTimeout(resolve, 1));
1378-
}
1379-
body.end();
1380-
resolve(null);
1381-
})().then(resolve, reject);
1371+
expect(await iterator.next()).toEqual({
1372+
done: false,
1373+
value: { event: "done", id: "EVENT_2", data: "{}" },
13821374
});
1375+
expect(await iterator.next()).toEqual({ done: true });
13831376

13841377
// Wait for both promises to resolve.
1385-
await Promise.all([reading, writing]);
13861378
});
13871379

13881380
test("supports the server writing data in a complete mess", async () => {
1389-
const body = new PassThrough();
1381+
const body = new ReadableStream({
1382+
async pull(controller) {
1383+
if (chunks.length) {
1384+
await new Promise((resolve) => setTimeout(resolve, 1));
1385+
const chunk = chunks.shift();
1386+
controller.enqueue(new TextEncoder().encode(chunk));
1387+
}
1388+
},
1389+
});
13901390
const stream = createStream(body);
13911391

13921392
// Create a stream of data chunks split on the pipe character for readability.
@@ -1407,40 +1407,20 @@ describe("Replicate client", () => {
14071407

14081408
const chunks = data.split("|");
14091409

1410-
// Consume the iterator in parallel to writing it.
1411-
const reading = new Promise((resolve, reject) => {
1412-
(async () => {
1413-
const iterator = stream[Symbol.asyncIterator]();
1414-
expect(await iterator.next()).toEqual({
1415-
done: false,
1416-
value: {
1417-
event: "output",
1418-
id: "EVENT_1",
1419-
data: "hello,\nthis is a new line,\nand this is a new line too",
1420-
},
1421-
});
1422-
expect(await iterator.next()).toEqual({
1423-
done: false,
1424-
value: { event: "done", id: "EVENT_2", data: "{}" },
1425-
});
1426-
expect(await iterator.next()).toEqual({ done: true });
1427-
})().then(resolve, reject);
1410+
const iterator = stream[Symbol.asyncIterator]();
1411+
expect(await iterator.next()).toEqual({
1412+
done: false,
1413+
value: {
1414+
event: "output",
1415+
id: "EVENT_1",
1416+
data: "hello,\nthis is a new line,\nand this is a new line too",
1417+
},
14281418
});
1429-
1430-
// Write the chunks to the stream at an interval.
1431-
const writing = new Promise((resolve, reject) => {
1432-
(async () => {
1433-
for await (const chunk of chunks) {
1434-
body.write(chunk);
1435-
await new Promise((resolve) => setTimeout(resolve, 1));
1436-
}
1437-
body.end();
1438-
resolve(null);
1439-
})().then(resolve, reject);
1419+
expect(await iterator.next()).toEqual({
1420+
done: false,
1421+
value: { event: "done", id: "EVENT_2", data: "{}" },
14401422
});
1441-
1442-
// Wait for both promises to resolve.
1443-
await Promise.all([reading, writing]);
1423+
expect(await iterator.next()).toEqual({ done: true });
14441424
});
14451425

14461426
test("supports ending without a done", async () => {

‎integration/cloudflare-worker/.npmrc

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
package-lock=false
2-
2+
audit=false
3+
fund=false

‎integration/cloudflare-worker/index.test.js

+10-5
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@ import { unstable_dev as dev } from "wrangler";
33
import { test, after, before, describe } from "node:test";
44
import assert from "node:assert";
55

6-
/** @type {import("wrangler").UnstableDevWorker} */
76
describe("CloudFlare Worker", () => {
7+
/** @type {import("wrangler").UnstableDevWorker} */
88
let worker;
99

1010
before(async () => {
@@ -22,15 +22,20 @@ describe("CloudFlare Worker", () => {
2222
await worker.stop();
2323
});
2424

25-
test("worker streams back a response", { timeout: 1000 }, async () => {
25+
test("worker streams back a response", { timeout: 5000 }, async () => {
2626
const resp = await worker.fetch();
2727
const text = await resp.text();
2828

29-
assert.ok(resp.ok, "status is 2xx");
30-
assert(text.length > 0, "body.length is greater than 0");
29+
assert.ok(resp.ok, `expected status to be 2xx but got ${resp.status}`);
30+
assert(
31+
text.length > 0,
32+
"expected body to have content but got body.length of 0"
33+
);
3134
assert(
3235
text.includes("Colin CloudFlare"),
33-
"body includes stream characters"
36+
`expected body to include "Colin CloudFlare" but got ${JSON.stringify(
37+
text
38+
)}`
3439
);
3540
});
3641
});

‎integration/commonjs/.npmrc

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
package-lock=false
2-
2+
audit=false
3+
fund=false

‎integration/esm/.npmrc

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
package-lock=false
2-
2+
audit=false
3+
fund=false

‎integration/typescript/.npmrc

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
package-lock=false
2-
2+
audit=false
3+
fund=false

‎lib/stream.js

+12-5
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ function createReadableStream({ url, fetch, options = {} }) {
6262
const request = new Request(url, init);
6363
controller.error(
6464
new ApiError(
65-
`Request to ${url} failed with status ${response.status}`,
65+
`Request to ${url} failed with status ${response.status}: ${text}`,
6666
request,
6767
response
6868
)
@@ -72,15 +72,22 @@ function createReadableStream({ url, fetch, options = {} }) {
7272
const stream = response.body
7373
.pipeThrough(new TextDecoderStream())
7474
.pipeThrough(new EventSourceParserStream());
75+
7576
for await (const event of stream) {
7677
if (event.event === "error") {
7778
controller.error(new Error(event.data));
78-
} else {
79-
controller.enqueue(
80-
new ServerSentEvent(event.event, event.data, event.id)
81-
);
79+
break;
80+
}
81+
82+
controller.enqueue(
83+
new ServerSentEvent(event.event, event.data, event.id)
84+
);
85+
86+
if (event.event === "done") {
87+
break;
8288
}
8389
}
90+
8491
controller.close();
8592
},
8693
});

0 commit comments

Comments
 (0)
Please sign in to comment.