Skip to content

Commit c1a12b0

Browse files
authored
Bug fixes for the wait option in replicate.run (#315)
There were a couple of small bugs in the current implementation: 1. We would pass non-boolean, non-integer values through to `predictions.create` when it was an object with an interval, resulting in the blocking mode being used accidentally. 2. We would pass the boolean/integer values through to `wait` which would create a runtime error when the `wait` function expects an object. 3. We continued to poll for the prediction despite the blocking response returning the output data. This PR addresses these three issues by checking if the run should be blocking and passing the correct arguments in the correct places. We also assume that if the returned prediction is not in `starting` state then it is completed. This isn't ideal but works for the moment. Lastly, in the case where the blocking request times out the client will fall back to polling at the default interval.
1 parent be0f323 commit c1a12b0

File tree

2 files changed

+27
-23
lines changed

2 files changed

+27
-23
lines changed

index.d.ts

+5-5
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,9 @@ declare module "replicate" {
9393
model: string;
9494
version: string;
9595
input: object;
96-
output?: any;
96+
output?: any; // TODO: this should be `unknown`
9797
source: "api" | "web";
98-
error?: any;
98+
error?: unknown;
9999
logs?: string;
100100
metrics?: {
101101
predict_time?: number;
@@ -156,7 +156,7 @@ declare module "replicate" {
156156
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
157157
options: {
158158
input: object;
159-
wait?: boolean | number | { mode?: "poll"; interval?: number };
159+
wait?: boolean | number | { interval?: number };
160160
webhook?: string;
161161
webhook_events_filter?: WebhookEventType[];
162162
signal?: AbortSignal;
@@ -215,7 +215,7 @@ declare module "replicate" {
215215
stream?: boolean;
216216
webhook?: string;
217217
webhook_events_filter?: WebhookEventType[];
218-
block?: boolean;
218+
wait?: boolean | number | { mode?: "poll"; interval?: number };
219219
}
220220
): Promise<Prediction>;
221221
};
@@ -304,7 +304,7 @@ declare module "replicate" {
304304
stream?: boolean;
305305
webhook?: string;
306306
webhook_events_filter?: WebhookEventType[];
307-
block?: boolean;
307+
wait?: boolean | number | { mode?: "poll"; interval?: number };
308308
} & ({ version: string } | { model: string })
309309
): Promise<Prediction>;
310310
get(prediction_id: string): Promise<Prediction>;

index.js

+22-18
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,20 @@ class Replicate {
147147
const { wait, signal, ...data } = options;
148148

149149
const identifier = ModelVersionIdentifier.parse(ref);
150+
const isBlocking = typeof wait === "boolean" || typeof wait === "number";
150151

151152
let prediction;
152153
if (identifier.version) {
153154
prediction = await this.predictions.create({
154155
...data,
155156
version: identifier.version,
156-
wait: wait,
157+
wait: isBlocking ? wait : false,
157158
});
158159
} else if (identifier.owner && identifier.name) {
159160
prediction = await this.predictions.create({
160161
...data,
161162
model: `${identifier.owner}/${identifier.name}`,
162-
wait: wait,
163+
wait: isBlocking ? wait : false,
163164
});
164165
} else {
165166
throw new Error("Invalid model version identifier");
@@ -170,23 +171,26 @@ class Replicate {
170171
progress(prediction);
171172
}
172173

173-
prediction = await this.wait(
174-
prediction,
175-
wait || {},
176-
async (updatedPrediction) => {
177-
// Call progress callback with the updated prediction object
178-
if (progress) {
179-
progress(updatedPrediction);
174+
const isDone = isBlocking && prediction.status !== "starting";
175+
if (!isDone) {
176+
prediction = await this.wait(
177+
prediction,
178+
isBlocking ? {} : wait,
179+
async (updatedPrediction) => {
180+
// Call progress callback with the updated prediction object
181+
if (progress) {
182+
progress(updatedPrediction);
183+
}
184+
185+
// We handle the cancel later in the function.
186+
if (signal && signal.aborted) {
187+
return true; // stop polling
188+
}
189+
190+
return false; // continue polling
180191
}
181-
182-
// We handle the cancel later in the function.
183-
if (signal && signal.aborted) {
184-
return true; // stop polling
185-
}
186-
187-
return false; // continue polling
188-
}
189-
);
192+
);
193+
}
190194

191195
if (signal && signal.aborted) {
192196
prediction = await this.predictions.cancel(prediction.id);

0 commit comments

Comments
 (0)