diff --git a/biome.json b/biome.json index 094cf0e..ba1c236 100644 --- a/biome.json +++ b/biome.json @@ -1,11 +1,7 @@ { "$schema": "https://biomejs.dev/schemas/1.0.0/schema.json", "files": { - "ignore": [ - ".wrangler", - "node_modules", - "vendor/*" - ] + "ignore": [".wrangler", "node_modules", "vendor/*"] }, "formatter": { "indentStyle": "space", diff --git a/index.d.ts b/index.d.ts index 2b183d0..678d7a0 100644 --- a/index.d.ts +++ b/index.d.ts @@ -173,6 +173,7 @@ declare module "replicate" { webhook?: string; webhook_events_filter?: WebhookEventType[]; signal?: AbortSignal; + useFileOutput?: boolean; } ): AsyncGenerator; diff --git a/index.js b/index.js index b1248e7..3fc14dd 100644 --- a/index.js +++ b/index.js @@ -315,7 +315,12 @@ class Replicate { * @yields {ServerSentEvent} Each streamed event from the prediction */ async *stream(ref, options) { - const { wait, signal, ...data } = options; + const { + wait, + signal, + useFileOutput = this.useFileOutput, + ...data + } = options; const identifier = ModelVersionIdentifier.parse(ref); @@ -338,7 +343,10 @@ class Replicate { const stream = createReadableStream({ url: prediction.urls.stream, fetch: this.fetch, - ...(signal ? { options: { signal } } : {}), + options: { + useFileOutput, + ...(signal ? { signal } : {}), + }, }); yield* streamAsyncIterator(stream); diff --git a/index.test.ts b/index.test.ts index 4905908..65eb93e 100644 --- a/index.test.ts +++ b/index.test.ts @@ -1906,8 +1906,12 @@ describe("Replicate client", () => { // Continue with tests for other methods describe("createReadableStream", () => { - function createStream(body: string | ReadableStream, status = 200) { - const streamEndpoint = "https://stream.replicate.com/fake_stream"; + function createStream( + body: string | ReadableStream, + status = 200, + streamEndpoint = "https://stream.replicate.com/fake_stream", + options: { useFileOutput?: boolean } = {} + ) { const fetch = jest.fn((url) => { if (url !== streamEndpoint) { throw new Error(`Unmocked call to fetch() with url: ${url}`); @@ -1917,6 +1921,7 @@ describe("Replicate client", () => { return createReadableStream({ url: streamEndpoint, fetch: fetch as any, + options, }); } @@ -2193,5 +2198,95 @@ describe("Replicate client", () => { ); expect(await iterator.next()).toEqual({ done: true }); }); + + describe("file streams", () => { + test("emits FileOutput objects", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data:  + + event: output + id: EVENT_2 + data: https://delivery.replicate.com/my_file.png + + event: done + id: EVENT_3 + data: {} + + `.replace(/^[ ]+/gm, ""), + 200, + "https://stream.replicate.com/v1/files/abcd" + ); + + const iterator = stream[Symbol.asyncIterator](); + const { value: event1 } = await iterator.next(); + expect(event1.data).toBeInstanceOf(ReadableStream); + expect(event1.data.url().href).toEqual( + "" + ); + + const { value: event2 } = await iterator.next(); + expect(event2.data).toBeInstanceOf(ReadableStream); + expect(event2.data.url().href).toEqual( + "https://delivery.replicate.com/my_file.png" + ); + + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_3", data: "{}" }, + }); + + expect(await iterator.next()).toEqual({ done: true }); + }); + + test("emits strings when useFileOutput is false", async () => { + const stream = createStream( + ` + event: output + id: EVENT_1 + data:  + + event: output + id: EVENT_2 + data: https://delivery.replicate.com/my_file.png + + event: done + id: EVENT_3 + data: {} + + `.replace(/^[ ]+/gm, ""), + 200, + "https://stream.replicate.com/v1/files/abcd", + { useFileOutput: false } + ); + + const iterator = stream[Symbol.asyncIterator](); + + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_1", + data: "", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { + event: "output", + id: "EVENT_2", + data: "https://delivery.replicate.com/my_file.png", + }, + }); + expect(await iterator.next()).toEqual({ + done: false, + value: { event: "done", id: "EVENT_3", data: "{}" }, + }); + + expect(await iterator.next()).toEqual({ done: true }); + }); + }); }); }); diff --git a/lib/stream.js b/lib/stream.js index 2c899bd..802a98e 100644 --- a/lib/stream.js +++ b/lib/stream.js @@ -53,6 +53,7 @@ class ServerSentEvent { */ function createReadableStream({ url, fetch, options = {} }) { const { useFileOutput = true, headers = {}, ...initOptions } = options; + const shouldProcessFileOutput = useFileOutput && isFileStream(url); return new ReadableStream({ async start(controller) { @@ -89,11 +90,11 @@ function createReadableStream({ url, fetch, options = {} }) { let data = event.data; if ( - useFileOutput && - typeof data === "string" && - (data.startsWith("https:") || data.startsWith("data:")) + event.event === "output" && + shouldProcessFileOutput && + typeof data === "string" ) { - data = createFileOutput({ data, fetch }); + data = createFileOutput({ url: data, fetch }); } controller.enqueue(new ServerSentEvent(event.event, data, event.id)); @@ -169,6 +170,13 @@ function createFileOutput({ url, fetch }) { }); } +function isFileStream(url) { + try { + return new URL(url).pathname.startsWith("/v1/files/"); + } catch {} + return false; +} + module.exports = { createFileOutput, createReadableStream,