diff --git a/index.d.ts b/index.d.ts index 78457e0..eabcc9b 100644 --- a/index.d.ts +++ b/index.d.ts @@ -156,7 +156,7 @@ declare module "replicate" { identifier: `${string}/${string}` | `${string}/${string}:${string}`, options: { input: object; - wait?: { interval?: number }; + wait?: boolean | number | { mode?: "poll"; interval?: number }; webhook?: string; webhook_events_filter?: WebhookEventType[]; signal?: AbortSignal; @@ -189,6 +189,7 @@ declare module "replicate" { wait( prediction: Prediction, options?: { + mode?: "poll"; interval?: number; }, stop?: (prediction: Prediction) => Promise @@ -210,9 +211,11 @@ declare module "replicate" { deployment_name: string, options: { input: object; + /** @deprecated */ stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; + block?: boolean; } ): Promise; }; @@ -301,6 +304,7 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; + block?: boolean; } & ({ version: string } | { model: string }) ): Promise; get(prediction_id: string): Promise; diff --git a/index.js b/index.js index f0d3e75..712bc59 100644 --- a/index.js +++ b/index.js @@ -133,7 +133,7 @@ class Replicate { * @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version" * @param {object} options * @param {object} options.input - Required. An object with the model inputs - * @param {object} [options.wait] - Options for waiting for the prediction to finish + * @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish. * @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500 * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) @@ -153,11 +153,13 @@ class Replicate { prediction = await this.predictions.create({ ...data, version: identifier.version, + wait: wait, }); } else if (identifier.owner && identifier.name) { prediction = await this.predictions.create({ ...data, model: `${identifier.owner}/${identifier.name}`, + wait: wait, }); } else { throw new Error("Invalid model version identifier"); diff --git a/lib/deployments.js b/lib/deployments.js index 56ed240..6cab261 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -7,13 +7,13 @@ const { transformFileInputs } = require("./util"); * @param {string} deployment_name - Required. The name of the deployment * @param {object} options * @param {object} options.input - Required. An object with the model inputs - * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) + * @param {boolean} [options.block] - Whether to wait until the prediction is completed before returning. Defaults to false * @returns {Promise} Resolves with the created prediction data */ async function createPrediction(deployment_owner, deployment_name, options) { - const { stream, input, ...data } = options; + const { input, block, ...data } = options; if (data.webhook) { try { @@ -24,10 +24,16 @@ async function createPrediction(deployment_owner, deployment_name, options) { } } + const headers = {}; + if (block) { + headers["Prefer"] = "wait"; + } + const response = await this.request( `/deployments/${deployment_owner}/${deployment_name}/predictions`, { method: "POST", + headers, data: { ...data, input: await transformFileInputs( @@ -35,7 +41,6 @@ async function createPrediction(deployment_owner, deployment_name, options) { input, this.fileEncodingStrategy ), - stream, }, } ); diff --git a/lib/predictions.js b/lib/predictions.js index 88bed32..f8e1c5a 100644 --- a/lib/predictions.js +++ b/lib/predictions.js @@ -9,11 +9,11 @@ const { transformFileInputs } = require("./util"); * @param {object} options.input - Required. An object with the model inputs * @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output * @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`) - * @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false. Streaming is now enabled by default for all predictions. For more information, see https://replicate.com/changelog/2024-07-15-streams-always-available-stream-parameter-deprecated + * @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false * @returns {Promise} Resolves with the created prediction */ async function createPrediction(options) { - const { model, version, input, ...data } = options; + const { model, version, input, wait, ...data } = options; if (data.webhook) { try { @@ -24,10 +24,21 @@ async function createPrediction(options) { } } + const headers = {}; + if (wait) { + if (typeof wait === "number") { + const n = Math.max(1, Math.ceil(Number(wait)) || 1); + headers["Prefer"] = `wait=${n}`; + } else { + headers["Prefer"] = "wait"; + } + } + let response; if (version) { response = await this.request("/predictions", { method: "POST", + headers, data: { ...data, input: await transformFileInputs( @@ -41,6 +52,7 @@ async function createPrediction(options) { } else if (model) { response = await this.request(`/models/${model}/predictions`, { method: "POST", + headers, data: { ...data, input: await transformFileInputs(