Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add wait parameter to prediction creation methods #308

Merged
merged 7 commits into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ declare module "replicate" {
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
options: {
input: object;
wait?: { interval?: number };
wait?: boolean | number | { mode?: "poll"; interval?: number };
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
Expand Down Expand Up @@ -189,6 +189,7 @@ declare module "replicate" {
wait(
prediction: Prediction,
options?: {
mode?: "poll";
interval?: number;
},
stop?: (prediction: Prediction) => Promise<boolean>
Expand All @@ -210,9 +211,11 @@ declare module "replicate" {
deployment_name: string,
options: {
input: object;
/** @deprecated */
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
block?: boolean;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I preferred the version we had that exposed wait here. Where wait was one of:

wait?: { mode: "block", timeout?: number } | { mode: "poll", interval?: number }

This way we have only one param consistently, and it defaults to {mode: "block"}.

}
): Promise<Prediction>;
};
Expand Down Expand Up @@ -301,6 +304,7 @@ declare module "replicate" {
stream?: boolean;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
block?: boolean;
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
Expand Down
4 changes: 3 additions & 1 deletion index.js
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ class Replicate {
* @param {string} ref - Required. The model version identifier in the format "owner/name" or "owner/name:version"
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {object} [options.wait] - Options for waiting for the prediction to finish
* @param {object} [options.wait] - Options for waiting for the prediction to finish. If `wait` is explicitly true, the function will block and wait for the prediction to finish.
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
Expand All @@ -153,11 +153,13 @@ class Replicate {
prediction = await this.predictions.create({
...data,
version: identifier.version,
wait: wait,
});
} else if (identifier.owner && identifier.name) {
prediction = await this.predictions.create({
...data,
model: `${identifier.owner}/${identifier.name}`,
wait: wait,
});
} else {
throw new Error("Invalid model version identifier");
Expand Down
11 changes: 8 additions & 3 deletions lib/deployments.js
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ const { transformFileInputs } = require("./util");
* @param {string} deployment_name - Required. The name of the deployment
* @param {object} options
* @param {object} options.input - Required. An object with the model inputs
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.block] - Whether to wait until the prediction is completed before returning. Defaults to false
* @returns {Promise<object>} Resolves with the created prediction data
*/
async function createPrediction(deployment_owner, deployment_name, options) {
const { stream, input, ...data } = options;
const { input, block, ...data } = options;

if (data.webhook) {
try {
Expand All @@ -24,18 +24,23 @@ async function createPrediction(deployment_owner, deployment_name, options) {
}
}

const headers = {};
if (block) {
headers["Prefer"] = "wait";
}

const response = await this.request(
`/deployments/${deployment_owner}/${deployment_name}/predictions`,
{
method: "POST",
headers,
data: {
...data,
input: await transformFileInputs(
this,
input,
this.fileEncodingStrategy
),
stream,
},
}
);
Expand Down
16 changes: 14 additions & 2 deletions lib/predictions.js
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ const { transformFileInputs } = require("./util");
* @param {object} options.input - Required. An object with the model inputs
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false. Streaming is now enabled by default for all predictions. For more information, see https://replicate.com/changelog/2024-07-15-streams-always-available-stream-parameter-deprecated
* @param {boolean|integer} [options.wait] - Whether to wait until the prediction is completed before returning. If an integer is provided, it will wait for that many seconds. Defaults to false
* @returns {Promise<object>} Resolves with the created prediction
*/
async function createPrediction(options) {
const { model, version, input, ...data } = options;
const { model, version, input, wait, ...data } = options;

if (data.webhook) {
try {
Expand All @@ -24,10 +24,21 @@ async function createPrediction(options) {
}
}

const headers = {};
if (wait) {
if (typeof wait === "number") {
const n = Math.max(1, Math.ceil(Number(wait)) || 1);
headers["Prefer"] = `wait=${n}`;
} else {
headers["Prefer"] = "wait";
}
}

let response;
if (version) {
response = await this.request("/predictions", {
method: "POST",
headers,
data: {
...data,
input: await transformFileInputs(
Expand All @@ -41,6 +52,7 @@ async function createPrediction(options) {
} else if (model) {
response = await this.request(`/models/${model}/predictions`, {
method: "POST",
headers,
data: {
...data,
input: await transformFileInputs(
Expand Down