Skip to content

Commit ac5caba

Browse files
authored
Add support for AbortSignal to all API methods (#339)
This PR adds support for passing `AbortSignal` to all API methods that make HTTP requests, these are passed directly into the native `fetch()` implementation, so it's up to the user to handle the `AbortError` raised, if any. ```js const controller = new AbortController(); try { const prediction = await replicate.predictions.create({ version: 'xyz', ..., signal: controller.signal, }); } catch (err) { if (err instanceof DOMException && err.name === "AbortError") { ... } } ``` The `paginate` function also checks to see whether the signal was aborted before proceeding to the next iteration. If so it returns immediately to avoid making a redundant fetch call. This allows the client to take advantage of various frameworks that provide an `AbortSignal` instance to tear down any in flight requests.
1 parent 37aa363 commit ac5caba

15 files changed

+326
-82
lines changed

README.md

+10-4
Original file line numberDiff line numberDiff line change
@@ -1213,15 +1213,21 @@ Low-level method used by the Replicate client to interact with API endpoints.
12131213
const response = await replicate.request(route, parameters);
12141214
```
12151215

1216-
| name | type | description |
1217-
| -------------------- | ------ | ------------------------------------------------------------ |
1218-
| `options.route` | string | Required. REST API endpoint path. |
1219-
| `options.parameters` | object | URL, query, and request body parameters for the given route. |
1216+
| name | type | description |
1217+
| -------------------- | ------------------- | ----------- |
1218+
| `options.route` | `string` | Required. REST API endpoint path.
1219+
| `options.params` | `object` | URL query parameters for the given route. |
1220+
| `options.method` | `string` | HTTP method for the given route. |
1221+
| `options.headers` | `object` | Additional HTTP headers for the given route. |
1222+
| `options.data` | `object \| FormData` | Request body. |
1223+
| `options.signal` | `AbortSignal` | Optional `AbortSignal`. |
12201224

12211225
The `replicate.request()` method is used by the other methods
12221226
to interact with the Replicate API.
12231227
You can call this method directly to make other requests to the API.
12241228

1229+
The method accepts an `AbortSignal` which can be used to cancel the request in flight.
1230+
12251231
### `FileOutput`
12261232

12271233
`FileOutput` is a `ReadableStream` instance that represents a model file output. It can be used to stream file data to disk or as a `Response` body to an HTTP request.

index.d.ts

+81-33
Original file line numberDiff line numberDiff line change
@@ -183,10 +183,14 @@ declare module "replicate" {
183183
headers?: object | Headers;
184184
params?: object;
185185
data?: object;
186+
signal?: AbortSignal;
186187
}
187188
): Promise<Response>;
188189

189-
paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[T]>;
190+
paginate<T>(
191+
endpoint: () => Promise<Page<T>>,
192+
options?: { signal?: AbortSignal }
193+
): AsyncGenerator<T[]>;
190194

191195
wait(
192196
prediction: Prediction,
@@ -197,12 +201,15 @@ declare module "replicate" {
197201
): Promise<Prediction>;
198202

199203
accounts: {
200-
current(): Promise<Account>;
204+
current(options?: { signal?: AbortSignal }): Promise<Account>;
201205
};
202206

203207
collections: {
204-
list(): Promise<Page<Collection>>;
205-
get(collection_slug: string): Promise<Collection>;
208+
list(options?: { signal?: AbortSignal }): Promise<Page<Collection>>;
209+
get(
210+
collection_slug: string,
211+
options?: { signal?: AbortSignal }
212+
): Promise<Collection>;
206213
};
207214

208215
deployments: {
@@ -217,21 +224,26 @@ declare module "replicate" {
217224
webhook?: string;
218225
webhook_events_filter?: WebhookEventType[];
219226
wait?: number | boolean;
227+
signal?: AbortSignal;
220228
}
221229
): Promise<Prediction>;
222230
};
223231
get(
224232
deployment_owner: string,
225-
deployment_name: string
233+
deployment_name: string,
234+
options?: { signal?: AbortSignal }
235+
): Promise<Deployment>;
236+
create(
237+
deployment_config: {
238+
name: string;
239+
model: string;
240+
version: string;
241+
hardware: string;
242+
min_instances: number;
243+
max_instances: number;
244+
},
245+
options?: { signal?: AbortSignal }
226246
): Promise<Deployment>;
227-
create(deployment_config: {
228-
name: string;
229-
model: string;
230-
version: string;
231-
hardware: string;
232-
min_instances: number;
233-
max_instances: number;
234-
}): Promise<Deployment>;
235247
update(
236248
deployment_owner: string,
237249
deployment_name: string,
@@ -245,32 +257,45 @@ declare module "replicate" {
245257
| { hardware: string }
246258
| { min_instances: number }
247259
| { max_instances: number }
248-
)
260+
),
261+
options?: { signal?: AbortSignal }
249262
): Promise<Deployment>;
250263
delete(
251264
deployment_owner: string,
252-
deployment_name: string
265+
deployment_name: string,
266+
options?: { signal?: AbortSignal }
253267
): Promise<boolean>;
254-
list(): Promise<Page<Deployment>>;
268+
list(options?: { signal?: AbortSignal }): Promise<Page<Deployment>>;
255269
};
256270

257271
files: {
258272
create(
259273
file: Blob | Buffer,
260-
metadata?: Record<string, unknown>
274+
metadata?: Record<string, unknown>,
275+
options?: { signal?: AbortSignal }
261276
): Promise<FileObject>;
262-
list(): Promise<Page<FileObject>>;
263-
get(file_id: string): Promise<FileObject>;
264-
delete(file_id: string): Promise<boolean>;
277+
list(options?: { signal?: AbortSignal }): Promise<Page<FileObject>>;
278+
get(
279+
file_id: string,
280+
options?: { signal?: AbortSignal }
281+
): Promise<FileObject>;
282+
delete(
283+
file_id: string,
284+
options?: { signal?: AbortSignal }
285+
): Promise<boolean>;
265286
};
266287

267288
hardware: {
268-
list(): Promise<Hardware[]>;
289+
list(options?: { signal?: AbortSignal }): Promise<Hardware[]>;
269290
};
270291

271292
models: {
272-
get(model_owner: string, model_name: string): Promise<Model>;
273-
list(): Promise<Page<Model>>;
293+
get(
294+
model_owner: string,
295+
model_name: string,
296+
options?: { signal?: AbortSignal }
297+
): Promise<Model>;
298+
list(options?: { signal?: AbortSignal }): Promise<Page<Model>>;
274299
create(
275300
model_owner: string,
276301
model_name: string,
@@ -282,17 +307,26 @@ declare module "replicate" {
282307
paper_url?: string;
283308
license_url?: string;
284309
cover_image_url?: string;
310+
signal?: AbortSignal;
285311
}
286312
): Promise<Model>;
287313
versions: {
288-
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
314+
list(
315+
model_owner: string,
316+
model_name: string,
317+
options?: { signal?: AbortSignal }
318+
): Promise<ModelVersion[]>;
289319
get(
290320
model_owner: string,
291321
model_name: string,
292-
version_id: string
322+
version_id: string,
323+
options?: { signal?: AbortSignal }
293324
): Promise<ModelVersion>;
294325
};
295-
search(query: string): Promise<Page<Model>>;
326+
search(
327+
query: string,
328+
options?: { signal?: AbortSignal }
329+
): Promise<Page<Model>>;
296330
};
297331

298332
predictions: {
@@ -306,11 +340,18 @@ declare module "replicate" {
306340
webhook?: string;
307341
webhook_events_filter?: WebhookEventType[];
308342
wait?: boolean | number;
343+
signal?: AbortSignal;
309344
} & ({ version: string } | { model: string })
310345
): Promise<Prediction>;
311-
get(prediction_id: string): Promise<Prediction>;
312-
cancel(prediction_id: string): Promise<Prediction>;
313-
list(): Promise<Page<Prediction>>;
346+
get(
347+
prediction_id: string,
348+
options?: { signal?: AbortSignal }
349+
): Promise<Prediction>;
350+
cancel(
351+
prediction_id: string,
352+
options?: { signal?: AbortSignal }
353+
): Promise<Prediction>;
354+
list(options?: { signal?: AbortSignal }): Promise<Page<Prediction>>;
314355
};
315356

316357
trainings: {
@@ -323,17 +364,24 @@ declare module "replicate" {
323364
input: object;
324365
webhook?: string;
325366
webhook_events_filter?: WebhookEventType[];
367+
signal?: AbortSignal;
326368
}
327369
): Promise<Training>;
328-
get(training_id: string): Promise<Training>;
329-
cancel(training_id: string): Promise<Training>;
330-
list(): Promise<Page<Training>>;
370+
get(
371+
training_id: string,
372+
options?: { signal?: AbortSignal }
373+
): Promise<Training>;
374+
cancel(
375+
training_id: string,
376+
options?: { signal?: AbortSignal }
377+
): Promise<Training>;
378+
list(options?: { signal?: AbortSignal }): Promise<Page<Training>>;
331379
};
332380

333381
webhooks: {
334382
default: {
335383
secret: {
336-
get(): Promise<WebhookSecret>;
384+
get(options?: { signal?: AbortSignal }): Promise<WebhookSecret>;
337385
};
338386
};
339387
};

index.js

+12-5
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,7 @@ class Replicate {
225225
* @param {object} [options.params] - Query parameters
226226
* @param {object|Headers} [options.headers] - HTTP headers
227227
* @param {object} [options.data] - Body parameters
228+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request
228229
* @returns {Promise<Response>} - Resolves with the response object
229230
* @throws {ApiError} If the request failed
230231
*/
@@ -241,7 +242,7 @@ class Replicate {
241242
);
242243
}
243244

244-
const { method = "GET", params = {}, data } = options;
245+
const { method = "GET", params = {}, data, signal } = options;
245246

246247
for (const [key, value] of Object.entries(params)) {
247248
url.searchParams.append(key, value);
@@ -273,6 +274,7 @@ class Replicate {
273274
method,
274275
headers,
275276
body,
277+
signal,
276278
};
277279

278280
const shouldRetry =
@@ -354,15 +356,20 @@ class Replicate {
354356
* console.log(page);
355357
* }
356358
* @param {Function} endpoint - Function that returns a promise for the next page of results
359+
* @param {object} [options]
360+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request.
357361
* @yields {object[]} Each page of results
358362
*/
359-
async *paginate(endpoint) {
363+
async *paginate(endpoint, options = {}) {
360364
const response = await endpoint();
361365
yield response.results;
362-
if (response.next) {
366+
if (response.next && !(options.signal && options.signal.aborted)) {
363367
const nextPage = () =>
364-
this.request(response.next, { method: "GET" }).then((r) => r.json());
365-
yield* this.paginate(nextPage);
368+
this.request(response.next, {
369+
method: "GET",
370+
signal: options.signal,
371+
}).then((r) => r.json());
372+
yield* this.paginate(nextPage, options);
366373
}
367374
}
368375

index.test.ts

+84
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,90 @@ describe("Replicate client", () => {
9999
});
100100
});
101101

102+
describe("paginate", () => {
103+
test("pages through results", async () => {
104+
nock(BASE_URL)
105+
.get("/collections")
106+
.reply(200, {
107+
results: [
108+
{
109+
name: "Super resolution",
110+
slug: "super-resolution",
111+
description:
112+
"Upscaling models that create high-quality images from low-quality images.",
113+
},
114+
],
115+
next: `${BASE_URL}/collections?page=2`,
116+
previous: null,
117+
});
118+
nock(BASE_URL)
119+
.get("/collections?page=2")
120+
.reply(200, {
121+
results: [
122+
{
123+
name: "Image classification",
124+
slug: "image-classification",
125+
description: "Models that classify images.",
126+
},
127+
],
128+
next: null,
129+
previous: null,
130+
});
131+
132+
const iterator = client.paginate(client.collections.list);
133+
134+
const firstPage = (await iterator.next()).value;
135+
expect(firstPage.length).toBe(1);
136+
137+
const secondPage = (await iterator.next()).value;
138+
expect(secondPage.length).toBe(1);
139+
});
140+
141+
test("accepts an abort signal", async () => {
142+
nock(BASE_URL)
143+
.get("/collections")
144+
.reply(200, {
145+
results: [
146+
{
147+
name: "Super resolution",
148+
slug: "super-resolution",
149+
description:
150+
"Upscaling models that create high-quality images from low-quality images.",
151+
},
152+
],
153+
next: `${BASE_URL}/collections?page=2`,
154+
previous: null,
155+
});
156+
nock(BASE_URL)
157+
.get("/collections?page=2")
158+
.reply(200, {
159+
results: [
160+
{
161+
name: "Image classification",
162+
slug: "image-classification",
163+
description: "Models that classify images.",
164+
},
165+
],
166+
next: null,
167+
previous: null,
168+
});
169+
170+
const controller = new AbortController();
171+
const iterator = client.paginate(client.collections.list, {
172+
signal: controller.signal,
173+
});
174+
175+
const firstIteration = await iterator.next();
176+
expect(firstIteration.value.length).toBe(1);
177+
178+
controller.abort();
179+
180+
const secondIteration = await iterator.next();
181+
expect(secondIteration.value).toBeUndefined();
182+
expect(secondIteration.done).toBe(true);
183+
});
184+
});
185+
102186
describe("account.get", () => {
103187
test("Calls the correct API route", async () => {
104188
nock(BASE_URL).get("/account").reply(200, {

integration/next/pages/index.js

+4-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
export default () => (
2-
<main>
3-
<h1>Welcome to Next.js</h1>
4-
</main>
5-
)
2+
<main>
3+
<h1>Welcome to Next.js</h1>
4+
</main>
5+
);

0 commit comments

Comments
 (0)