Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable FileObject and blocking mode by default #316

Merged
merged 2 commits into from
Oct 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,9 @@ declare module "replicate" {
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
options: {
input: object;
wait?: boolean | number | { interval?: number };
wait?:
| { mode: "block"; interval?: number; timeout?: number }
| { mode: "poll"; interval?: number };
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
Expand Down Expand Up @@ -189,7 +191,6 @@ declare module "replicate" {
wait(
prediction: Prediction,
options?: {
mode?: "poll";
interval?: number;
},
stop?: (prediction: Prediction) => Promise<boolean>
Expand All @@ -215,7 +216,7 @@ declare module "replicate" {
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: boolean | number | { mode?: "poll"; interval?: number };
wait?: number | boolean;
}
): Promise<Prediction>;
};
Expand Down Expand Up @@ -304,7 +305,7 @@ declare module "replicate" {
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: boolean | number | { mode?: "poll"; interval?: number };
wait?: boolean | number;
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
Expand Down
18 changes: 8 additions & 10 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ class Replicate {
* @param {string} options.userAgent - Identifier of your app
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
* @param {boolean} [options.useFileOutput] - Set to `true` to return `FileOutput` objects from `run` instead of URLs, defaults to false.
* @param {boolean} [options.useFileOutput] - Set to `false` to disable `FileOutput` objects from `run` instead of URLs, defaults to true.
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
*/
constructor(options = {}) {
Expand All @@ -60,7 +60,7 @@ class Replicate {
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
this.fetch = options.fetch || globalThis.fetch;
this.fileEncodingStrategy = options.fileEncodingStrategy || "default";
this.useFileOutput = options.useFileOutput || false;
this.useFileOutput = options.useFileOutput === false ? false : true;

this.accounts = {
current: accounts.current.bind(this),
Expand Down Expand Up @@ -133,8 +133,7 @@ class Replicate {
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
* @param {{mode: "block", timeout?: number, interval?: number} | {mode: "poll", interval?: number }} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
Expand All @@ -144,23 +143,22 @@ class Replicate {
* @returns {Promise<object>} - Resolves with the output of running the model
*/
async run(ref, options, progress) {
const { wait, signal, ...data } = options;
const { wait = { mode: "block" }, signal, ...data } = options;

const identifier = ModelVersionIdentifier.parse(ref);
const isBlocking = typeof wait === "boolean" || typeof wait === "number";

let prediction;
if (identifier.version) {
prediction = await this.predictions.create({
...data,
version: identifier.version,
wait: isBlocking ? wait : false,
wait: wait.mode === "block" ? wait.timeout ?? true : false,
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
wait: isBlocking ? wait : false,
wait: wait.mode === "block" ? wait.timeout ?? true : false,
});
} else {
throw new Error("Invalid model version identifier");
Expand All @@ -171,11 +169,11 @@ class Replicate {
progress(prediction);
}

const isDone = isBlocking && prediction.status !== "starting";
const isDone = wait.mode === "block" && prediction.status !== "starting";
if (!isDone) {
prediction = await this.wait(
prediction,
isBlocking ? {} : wait,
{ interval: wait.mode === "poll" ? wait.interval : undefined },
async (updatedPrediction) => {
// Call progress callback with the updated prediction object
if (progress) {
Expand Down
24 changes: 18 additions & 6 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1310,7 +1310,7 @@ describe("Replicate client", () => {
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
input: { text: "Hello, world!" },
wait: { interval: 1 },
wait: { mode: "poll", interval: 1 },
},
(prediction) => {
const progress = parseProgressFromLogs(prediction);
Expand Down Expand Up @@ -1402,7 +1402,7 @@ describe("Replicate client", () => {
"replicate/hello-world",
{
input: { text: "Hello, world!" },
wait: { interval: 1 },
wait: { mode: "poll", interval: 1 },
},
progress
);
Expand Down Expand Up @@ -1448,12 +1448,18 @@ describe("Replicate client", () => {
});

await expect(
client.run("a/b-1.0:abc123", { input: { text: "Hello, world!" } })
client.run("a/b-1.0:abc123", {
wait: { mode: "poll" },
input: { text: "Hello, world!" },
})
).resolves.not.toThrow();
});

test("Throws an error for invalid identifiers", async () => {
const options = { input: { text: "Hello, world!" } };
const options = {
wait: { mode: "poll" } as { mode: "poll" },
input: { text: "Hello, world!" },
};

// @ts-expect-error
await expect(client.run("owner:abc123", options)).rejects.toThrow();
Expand All @@ -1469,6 +1475,7 @@ describe("Replicate client", () => {
await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: {
text: "Alice",
},
Expand All @@ -1492,7 +1499,7 @@ describe("Replicate client", () => {
})
.reply(201, {
id: "ufawqhfynnddngldkgtslldrkq",
status: "processing",
status: "starting",
})
.persist()
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
Expand All @@ -1510,6 +1517,7 @@ describe("Replicate client", () => {
const output = await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
signal,
},
Expand All @@ -1524,7 +1532,7 @@ describe("Replicate client", () => {
expect(onProgress).toHaveBeenNthCalledWith(
1,
expect.objectContaining({
status: "processing",
status: "starting",
})
);
expect(onProgress).toHaveBeenNthCalledWith(
Expand Down Expand Up @@ -1580,6 +1588,7 @@ describe("Replicate client", () => {
const output = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
}
)) as FileOutput;
Expand Down Expand Up @@ -1631,6 +1640,7 @@ describe("Replicate client", () => {
const output = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
}
)) as unknown as string;
Expand Down Expand Up @@ -1677,6 +1687,7 @@ describe("Replicate client", () => {
const [output] = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
}
)) as FileOutput[];
Expand Down Expand Up @@ -1724,6 +1735,7 @@ describe("Replicate client", () => {
const output = (await client.run(
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
{
wait: { mode: "poll" },
input: { text: "Hello, world!" },
}
)) as FileOutput;
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "replicate",
"version": "0.34.1",
"version": "1.0.0-beta.1",
"description": "JavaScript client for Replicate",
"repository": "github:replicate/replicate-javascript",
"homepage": "https://github.com/replicate/replicate-javascript#readme",
Expand Down