Skip to content

Commit 6ad79c8

Browse files
authored
Enable FileObject and blocking mode by default (#316)
* Enable FileObject and blocking mode by default * 1.0.0-beta.1
1 parent abe1029 commit 6ad79c8

File tree

5 files changed

+34
-23
lines changed

5 files changed

+34
-23
lines changed

index.d.ts

+5-4
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,9 @@ declare module "replicate" {
156156
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
157157
options: {
158158
input: object;
159-
wait?: boolean | number | { interval?: number };
159+
wait?:
160+
| { mode: "block"; interval?: number; timeout?: number }
161+
| { mode: "poll"; interval?: number };
160162
webhook?: string;
161163
webhook_events_filter?: WebhookEventType[];
162164
signal?: AbortSignal;
@@ -189,7 +191,6 @@ declare module "replicate" {
189191
wait(
190192
prediction: Prediction,
191193
options?: {
192-
mode?: "poll";
193194
interval?: number;
194195
},
195196
stop?: (prediction: Prediction) => Promise<boolean>
@@ -215,7 +216,7 @@ declare module "replicate" {
215216
stream?: boolean;
216217
webhook?: string;
217218
webhook_events_filter?: WebhookEventType[];
218-
wait?: boolean | number | { mode?: "poll"; interval?: number };
219+
wait?: number | boolean;
219220
}
220221
): Promise<Prediction>;
221222
};
@@ -304,7 +305,7 @@ declare module "replicate" {
304305
stream?: boolean;
305306
webhook?: string;
306307
webhook_events_filter?: WebhookEventType[];
307-
wait?: boolean | number | { mode?: "poll"; interval?: number };
308+
wait?: boolean | number;
308309
} & ({ version: string } | { model: string })
309310
): Promise<Prediction>;
310311
get(prediction_id: string): Promise<Prediction>;

index.js

+8-10
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ class Replicate {
4848
* @param {string} options.userAgent - Identifier of your app
4949
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
5050
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
51-
* @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false.
51+
* @param {boolean} [options.useFileOutput] - Set to `false` to disable `FileOutput` objects from `run` instead of URLs, defaults to true.
5252
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
5353
*/
5454
constructor(options = {}) {
@@ -60,7 +60,7 @@ class Replicate {
6060
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
6161
this.fetch = options.fetch || globalThis.fetch;
6262
this.fileEncodingStrategy = options.fileEncodingStrategy || "default";
63-
this.useFileOutput = options.useFileOutput || false;
63+
this.useFileOutput = options.useFileOutput === false ? false : true;
6464

6565
this.accounts = {
6666
current: accounts.current.bind(this),
@@ -133,8 +133,7 @@ class Replicate {
133133
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
134134
* @param {object} options
135135
* @param {object} options.input - Required. An object with the model inputs
136-
* @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
137-
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
136+
* @param {{mode: "block", timeout?: number, interval?: number} | {mode: "poll", interval?: number }} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
138137
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
139138
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
140139
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
@@ -144,23 +143,22 @@ class Replicate {
144143
* @returns {Promise<object>} - Resolves with the output of running the model
145144
*/
146145
async run(ref, options, progress) {
147-
const { wait, signal, ...data } = options;
146+
const { wait = { mode: "block" }, signal, ...data } = options;
148147

149148
const identifier = ModelVersionIdentifier.parse(ref);
150-
const isBlocking = typeof wait === "boolean" || typeof wait === "number";
151149

152150
let prediction;
153151
if (identifier.version) {
154152
prediction = await this.predictions.create({
155153
...data,
156154
version: identifier.version,
157-
wait: isBlocking ? wait : false,
155+
wait: wait.mode === "block" ? wait.timeout ?? true : false,
158156
});
159157
} else if (identifier.owner && identifier.name) {
160158
prediction = await this.predictions.create({
161159
...data,
162160
model: `${identifier.owner}/${identifier.name}`,
163-
wait: isBlocking ? wait : false,
161+
wait: wait.mode === "block" ? wait.timeout ?? true : false,
164162
});
165163
} else {
166164
throw new Error("Invalid model version identifier");
@@ -171,11 +169,11 @@ class Replicate {
171169
progress(prediction);
172170
}
173171

174-
const isDone = isBlocking && prediction.status !== "starting";
172+
const isDone = wait.mode === "block" && prediction.status !== "starting";
175173
if (!isDone) {
176174
prediction = await this.wait(
177175
prediction,
178-
isBlocking ? {} : wait,
176+
{ interval: wait.mode === "poll" ? wait.interval : undefined },
179177
async (updatedPrediction) => {
180178
// Call progress callback with the updated prediction object
181179
if (progress) {

index.test.ts

+18-6
Original file line numberDiff line numberDiff line change
@@ -1310,7 +1310,7 @@ describe("Replicate client", () => {
13101310
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
13111311
{
13121312
input: { text: "Hello, world!" },
1313-
wait: { interval: 1 },
1313+
wait: { mode: "poll", interval: 1 },
13141314
},
13151315
(prediction) => {
13161316
const progress = parseProgressFromLogs(prediction);
@@ -1402,7 +1402,7 @@ describe("Replicate client", () => {
14021402
"replicate/hello-world",
14031403
{
14041404
input: { text: "Hello, world!" },
1405-
wait: { interval: 1 },
1405+
wait: { mode: "poll", interval: 1 },
14061406
},
14071407
progress
14081408
);
@@ -1448,12 +1448,18 @@ describe("Replicate client", () => {
14481448
});
14491449

14501450
await expect(
1451-
client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } })
1451+
client.run("a/b-1.0:abc123", {
1452+
wait: { mode: "poll" },
1453+
input: { text: "Hello, world!" },
1454+
})
14521455
).resolves.not.toThrow();
14531456
});
14541457

14551458
test("Throws an error for invalid identifiers", async () => {
1456-
const options = { input: { text: "Hello, world!" } };
1459+
const options = {
1460+
wait: { mode: "poll" } as { mode: "poll" },
1461+
input: { text: "Hello, world!" },
1462+
};
14571463

14581464
// @ts-expect-error
14591465
await expect(client.run("owner:abc123", options)).rejects.toThrow();
@@ -1469,6 +1475,7 @@ describe("Replicate client", () => {
14691475
await client.run(
14701476
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
14711477
{
1478+
wait: { mode: "poll" },
14721479
input: {
14731480
text: "Alice",
14741481
},
@@ -1492,7 +1499,7 @@ describe("Replicate client", () => {
14921499
})
14931500
.reply(201, {
14941501
id: "ufawqhfynnddngldkgtslldrkq",
1495-
status: "processing",
1502+
status: "starting",
14961503
})
14971504
.persist()
14981505
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
@@ -1510,6 +1517,7 @@ describe("Replicate client", () => {
15101517
const output = await client.run(
15111518
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
15121519
{
1520+
wait: { mode: "poll" },
15131521
input: { text: "Hello, world!" },
15141522
signal,
15151523
},
@@ -1524,7 +1532,7 @@ describe("Replicate client", () => {
15241532
expect(onProgress).toHaveBeenNthCalledWith(
15251533
1,
15261534
expect.objectContaining({
1527-
status: "processing",
1535+
status: "starting",
15281536
})
15291537
);
15301538
expect(onProgress).toHaveBeenNthCalledWith(
@@ -1580,6 +1588,7 @@ describe("Replicate client", () => {
15801588
const output = (await client.run(
15811589
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
15821590
{
1591+
wait: { mode: "poll" },
15831592
input: { text: "Hello, world!" },
15841593
}
15851594
)) as FileOutput;
@@ -1631,6 +1640,7 @@ describe("Replicate client", () => {
16311640
const output = (await client.run(
16321641
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
16331642
{
1643+
wait: { mode: "poll" },
16341644
input: { text: "Hello, world!" },
16351645
}
16361646
)) as unknown as string;
@@ -1677,6 +1687,7 @@ describe("Replicate client", () => {
16771687
const [output] = (await client.run(
16781688
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
16791689
{
1690+
wait: { mode: "poll" },
16801691
input: { text: "Hello, world!" },
16811692
}
16821693
)) as FileOutput[];
@@ -1724,6 +1735,7 @@ describe("Replicate client", () => {
17241735
const output = (await client.run(
17251736
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
17261737
{
1738+
wait: { mode: "poll" },
17271739
input: { text: "Hello, world!" },
17281740
}
17291741
)) as FileOutput;

package-lock.json

+2-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
{
22
"name": "replicate",
3-
"version": "0.34.1",
3+
"version": "1.0.0-beta.1",
44
"description": "JavaScript client for Replicate",
55
"repository": "github:replicate/replicate-javascript",
66
"homepage": "https://github.com/replicate/replicate-javascript#readme",

0 commit comments

Comments
 (0)