Skip to content

Commit 173b31d

Browse files
matttaron
andauthored
Add wait parameter to prediction creation methods (#308)
* Add block parameter to prediction creation methods * Add support for block: or wait: true to run * Remove unused stream parameter * Apply suggestions from code review Co-authored-by: Aron Carroll <[email protected]> * Replace X-Sync header with Prefer: wait * Rename block to wait Update type definitions to capture expectations of wait parameter in run * Normalize wait value to nonzero integer --------- Co-authored-by: Aron Carroll <[email protected]>
1 parent 7f02a64 commit 173b31d

File tree

4 files changed

+30
-7
lines changed

4 files changed

+30
-7
lines changed

index.d.ts

+5-1
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ declare module "replicate" {
156156
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
157157
options: {
158158
input: object;
159-
wait?: { interval?: number };
159+
wait?: boolean | number | { mode?: "poll"; interval?: number };
160160
webhook?: string;
161161
webhook_events_filter?: WebhookEventType[];
162162
signal?: AbortSignal;
@@ -189,6 +189,7 @@ declare module "replicate" {
189189
wait(
190190
prediction: Prediction,
191191
options?: {
192+
mode?: "poll";
192193
interval?: number;
193194
},
194195
stop?: (prediction: Prediction) => Promise<boolean>
@@ -210,9 +211,11 @@ declare module "replicate" {
210211
deployment_name: string,
211212
options: {
212213
input: object;
214+
/** @deprecated */
213215
stream?: boolean;
214216
webhook?: string;
215217
webhook_events_filter?: WebhookEventType[];
218+
block?: boolean;
216219
}
217220
): Promise<Prediction>;
218221
};
@@ -301,6 +304,7 @@ declare module "replicate" {
301304
stream?: boolean;
302305
webhook?: string;
303306
webhook_events_filter?: WebhookEventType[];
307+
block?: boolean;
304308
} & ({ version: string } | { model: string })
305309
): Promise<Prediction>;
306310
get(prediction_id: string): Promise<Prediction>;

index.js

+3-1
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ class Replicate {
133133
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
134134
* @param {object} options
135135
* @param {object} options.input - Required. An object with the model inputs
136-
* @param {object} [options.wait] - Options for waiting for the prediction to finish
136+
* @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
137137
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
138138
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
139139
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
@@ -153,11 +153,13 @@ class Replicate {
153153
prediction = await this.predictions.create({
154154
...data,
155155
version: identifier.version,
156+
wait: wait,
156157
});
157158
} else if (identifier.owner && identifier.name) {
158159
prediction = await this.predictions.create({
159160
...data,
160161
model: `${identifier.owner}/${identifier.name}`,
162+
wait: wait,
161163
});
162164
} else {
163165
throw new Error("Invalid model version identifier");

lib/deployments.js

+8-3
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@ const { transformFileInputs } = require("./util");
77
* @param {string} deployment_name - Required. The name of the deployment
88
* @param {object} options
99
* @param {object} options.input - Required. An object with the model inputs
10-
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
1110
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
1211
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
12+
* @param {boolean} [options.block] - Whether to wait until the prediction is completed before returning. Defaults to false
1313
* @returns {Promise<object>} Resolves with the created prediction data
1414
*/
1515
async function createPrediction(deployment_owner, deployment_name, options) {
16-
const { stream, input, ...data } = options;
16+
const { input, block, ...data } = options;
1717

1818
if (data.webhook) {
1919
try {
@@ -24,18 +24,23 @@ async function createPrediction(deployment_owner, deployment_name, options) {
2424
}
2525
}
2626

27+
const headers = {};
28+
if (block) {
29+
headers["Prefer"] = "wait";
30+
}
31+
2732
const response = await this.request(
2833
`/deployments/${deployment_owner}/${deployment_name}/predictions`,
2934
{
3035
method: "POST",
36+
headers,
3137
data: {
3238
...data,
3339
input: await transformFileInputs(
3440
this,
3541
input,
3642
this.fileEncodingStrategy
3743
),
38-
stream,
3944
},
4045
}
4146
);

lib/predictions.js

+14-2
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,11 @@ const { transformFileInputs } = require("./util");
99
* @param {object} options.input - Required. An object with the model inputs
1010
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
1111
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
12-
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false. Streaming is now enabled by default for all predictions. For more information, see https://replicate.com/changelog/2024-07-15-streams-always-available-stream-parameter-deprecated
12+
* @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false
1313
* @returns {Promise<object>} Resolves with the created prediction
1414
*/
1515
async function createPrediction(options) {
16-
const { model, version, input, ...data } = options;
16+
const { model, version, input, wait, ...data } = options;
1717

1818
if (data.webhook) {
1919
try {
@@ -24,10 +24,21 @@ async function createPrediction(options) {
2424
}
2525
}
2626

27+
const headers = {};
28+
if (wait) {
29+
if (typeof wait === "number") {
30+
const n = Math.max(1, Math.ceil(Number(wait)) || 1);
31+
headers["Prefer"] = `wait=${n}`;
32+
} else {
33+
headers["Prefer"] = "wait";
34+
}
35+
}
36+
2737
let response;
2838
if (version) {
2939
response = await this.request("/predictions", {
3040
method: "POST",
41+
headers,
3142
data: {
3243
...data,
3344
input: await transformFileInputs(
@@ -41,6 +52,7 @@ async function createPrediction(options) {
4152
} else if (model) {
4253
response = await this.request(`/models/${model}/predictions`, {
4354
method: "POST",
55+
headers,
4456
data: {
4557
...data,
4658
input: await transformFileInputs(

0 commit comments

Comments
 (0)