diff --git a/index.d.ts b/index.d.ts index eabcc9b..4a8bbbf 100644 --- a/index.d.ts +++ b/index.d.ts @@ -162,7 +162,7 @@ declare module "replicate" { signal?: AbortSignal; }, progress?: (prediction: Prediction) => void - ): Promise; + ): Promise; stream( identifier: `${string}/${string}` | `${string}/${string}:${string}`, @@ -215,9 +215,9 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } - ): Promise; + ): Promise; }; get( deployment_owner: string, @@ -304,9 +304,9 @@ declare module "replicate" { stream?: boolean; webhook?: string; webhook_events_filter?: WebhookEventType[]; - block?: boolean; + wait?: boolean | number | { mode?: "poll"; interval?: number }; } & ({ version: string } | { model: string }) - ): Promise; + ): Promise; get(prediction_id: string): Promise; cancel(prediction_id: string): Promise; list(): Promise>; diff --git a/index.js b/index.js index 712bc59..c54b80b 100644 --- a/index.js +++ b/index.js @@ -165,6 +165,15 @@ class Replicate { throw new Error("Invalid model version identifier"); } + // When `wait` is set, the server may respond + // with the prediction output directly. + // If it hasn't finished, the prediction object is returned + // with an `id` property that can be used to poll for completion. + if (wait && !("id" in prediction)) { + const output = prediction; + return output; + } + // Call progress callback with the initial prediction object if (progress) { progress(prediction);