Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for AbortSignal to all API methods #339

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -1208,13 +1208,19 @@ const response = await replicate.request(route, parameters);

| name | type | description |
| -------------------- | ------ | ------------------------------------------------------------ |
| `options.route` | string | Required. REST API endpoint path. |
| `options.parameters` | object | URL, query, and request body parameters for the given route. |
| `options.route` | `string` | Required. REST API endpoint path. |
| `options.params` | `object` | URL query parameters for the given route. |
| `options.method` | `string` | HTTP method for the given route. |
| `options.headers` | `object` | Additional HTTP headers for the given route. |
| `options.data` | `object | FormData` | Request body. |
| `options.signal` | `AbortSignal` | Optional `AbortSignal`. |
Comment on lines +1211 to +1216
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you fix the formatting here so the table's dividers align vertically, please?


The `replicate.request()` method is used by the other methods
to interact with the Replicate API.
You can call this method directly to make other requests to the API.

The method accepts an `AbortSignal` which can be used to cancel the request in flight.

### `FileOutput`

`FileOutput` is a `ReadableStream` instance that represents a model file output. It can be used to stream file data to disk or as a `Response` body to an HTTP request.
Expand Down
114 changes: 81 additions & 33 deletions index.d.ts
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,14 @@ declare module "replicate" {
headers?: object | Headers;
params?: object;
data?: object;
signal?: AbortSignal;
}
): Promise<Response>;

paginate<T>(endpoint: () => Promise<Page<T>>): AsyncGenerator<[T]>;
paginate<T>(
endpoint: () => Promise<Page<T>>,
options?: { signal?: AbortSignal }
): AsyncGenerator<T[]>;

wait(
prediction: Prediction,
Expand All @@ -197,12 +201,15 @@ declare module "replicate" {
): Promise<Prediction>;

accounts: {
current(): Promise<Account>;
current(options?: { signal?: AbortSignal }): Promise<Account>;
};

collections: {
list(): Promise<Page<Collection>>;
get(collection_slug: string): Promise<Collection>;
list(options?: { signal?: AbortSignal }): Promise<Page<Collection>>;
get(
collection_slug: string,
options?: { signal?: AbortSignal }
): Promise<Collection>;
};

deployments: {
Expand All @@ -217,21 +224,26 @@ declare module "replicate" {
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: number | boolean;
signal?: AbortSignal;
}
): Promise<Prediction>;
};
get(
deployment_owner: string,
deployment_name: string
deployment_name: string,
options?: { signal?: AbortSignal }
): Promise<Deployment>;
create(
deployment_config: {
name: string;
model: string;
version: string;
hardware: string;
min_instances: number;
max_instances: number;
},
options?: { signal?: AbortSignal }
): Promise<Deployment>;
create(deployment_config: {
name: string;
model: string;
version: string;
hardware: string;
min_instances: number;
max_instances: number;
}): Promise<Deployment>;
update(
deployment_owner: string,
deployment_name: string,
Expand All @@ -245,32 +257,45 @@ declare module "replicate" {
| { hardware: string }
| { min_instances: number }
| { max_instances: number }
)
),
options?: { signal?: AbortSignal }
): Promise<Deployment>;
delete(
deployment_owner: string,
deployment_name: string
deployment_name: string,
options?: { signal?: AbortSignal }
): Promise<boolean>;
list(): Promise<Page<Deployment>>;
list(options?: { signal?: AbortSignal }): Promise<Page<Deployment>>;
};

files: {
create(
file: Blob | Buffer,
metadata?: Record<string, unknown>
metadata?: Record<string, unknown>,
options?: { signal?: AbortSignal }
): Promise<FileObject>;
list(options?: { signal?: AbortSignal }): Promise<Page<FileObject>>;
get(
file_id: string,
options?: { signal?: AbortSignal }
): Promise<FileObject>;
list(): Promise<Page<FileObject>>;
get(file_id: string): Promise<FileObject>;
delete(file_id: string): Promise<boolean>;
delete(
file_id: string,
options?: { signal?: AbortSignal }
): Promise<boolean>;
};

hardware: {
list(): Promise<Hardware[]>;
list(options?: { signal?: AbortSignal }): Promise<Hardware[]>;
};

models: {
get(model_owner: string, model_name: string): Promise<Model>;
list(): Promise<Page<Model>>;
get(
model_owner: string,
model_name: string,
options?: { signal?: AbortSignal }
): Promise<Model>;
list(options?: { signal?: AbortSignal }): Promise<Page<Model>>;
create(
model_owner: string,
model_name: string,
Expand All @@ -282,17 +307,26 @@ declare module "replicate" {
paper_url?: string;
license_url?: string;
cover_image_url?: string;
signal?: AbortSignal;
}
): Promise<Model>;
versions: {
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
list(
model_owner: string,
model_name: string,
options?: { signal?: AbortSignal }
): Promise<ModelVersion[]>;
get(
model_owner: string,
model_name: string,
version_id: string
version_id: string,
options?: { signal?: AbortSignal }
): Promise<ModelVersion>;
};
search(query: string): Promise<Page<Model>>;
search(
query: string,
options?: { signal?: AbortSignal }
): Promise<Page<Model>>;
};

predictions: {
Expand All @@ -306,11 +340,18 @@ declare module "replicate" {
webhook?: string;
webhook_events_filter?: WebhookEventType[];
wait?: boolean | number;
signal?: AbortSignal;
} & ({ version: string } | { model: string })
): Promise<Prediction>;
get(prediction_id: string): Promise<Prediction>;
cancel(prediction_id: string): Promise<Prediction>;
list(): Promise<Page<Prediction>>;
get(
prediction_id: string,
options?: { signal?: AbortSignal }
): Promise<Prediction>;
cancel(
prediction_id: string,
options?: { signal?: AbortSignal }
): Promise<Prediction>;
list(options?: { signal?: AbortSignal }): Promise<Page<Prediction>>;
};

trainings: {
Expand All @@ -323,17 +364,24 @@ declare module "replicate" {
input: object;
webhook?: string;
webhook_events_filter?: WebhookEventType[];
signal?: AbortSignal;
}
): Promise<Training>;
get(training_id: string): Promise<Training>;
cancel(training_id: string): Promise<Training>;
list(): Promise<Page<Training>>;
get(
training_id: string,
options?: { signal?: AbortSignal }
): Promise<Training>;
cancel(
training_id: string,
options?: { signal?: AbortSignal }
): Promise<Training>;
list(options?: { signal?: AbortSignal }): Promise<Page<Training>>;
};

webhooks: {
default: {
secret: {
get(): Promise<WebhookSecret>;
get(options?: { signal?: AbortSignal }): Promise<WebhookSecret>;
};
};
};
Expand Down
17 changes: 12 additions & 5 deletions index.js
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class Replicate {
* @param {object} [options.params] - Query parameters
* @param {object|Headers} [options.headers] - HTTP headers
* @param {object} [options.data] - Body parameters
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request
* @returns {Promise<Response>} - Resolves with the response object
* @throws {ApiError} If the request failed
*/
Expand All @@ -241,7 +242,7 @@ class Replicate {
);
}

const { method = "GET", params = {}, data } = options;
const { method = "GET", params = {}, data, signal } = options;

for (const [key, value] of Object.entries(params)) {
url.searchParams.append(key, value);
Expand Down Expand Up @@ -273,6 +274,7 @@ class Replicate {
method,
headers,
body,
signal,
};

const shouldRetry =
Expand Down Expand Up @@ -354,15 +356,20 @@ class Replicate {
* console.log(page);
* }
* @param {Function} endpoint - Function that returns a promise for the next page of results
* @param {object} [options]
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the request.
* @yields {object[]} Each page of results
*/
async *paginate(endpoint) {
async *paginate(endpoint, options = {}) {
const response = await endpoint();
yield response.results;
if (response.next) {
if (response.next && !(options.signal && options.signal.aborted)) {
const nextPage = () =>
this.request(response.next, { method: "GET" }).then((r) => r.json());
yield* this.paginate(nextPage);
this.request(response.next, {
method: "GET",
signal: options.signal,
}).then((r) => r.json());
yield* this.paginate(nextPage, options);
}
}

Expand Down
84 changes: 84 additions & 0 deletions index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,90 @@ describe("Replicate client", () => {
});
});

describe("paginate", () => {
test("pages through results", async () => {
nock(BASE_URL)
.get("/collections")
.reply(200, {
results: [
{
name: "Super resolution",
slug: "super-resolution",
description:
"Upscaling models that create high-quality images from low-quality images.",
},
],
next: `${BASE_URL}/collections?page=2`,
previous: null,
});
nock(BASE_URL)
.get("/collections?page=2")
.reply(200, {
results: [
{
name: "Image classification",
slug: "image-classification",
description: "Models that classify images.",
},
],
next: null,
previous: null,
});

const iterator = client.paginate(client.collections.list);

const firstPage = (await iterator.next()).value;
expect(firstPage.length).toBe(1);

const secondPage = (await iterator.next()).value;
expect(secondPage.length).toBe(1);
});

test("accepts an abort signal", async () => {
nock(BASE_URL)
.get("/collections")
.reply(200, {
results: [
{
name: "Super resolution",
slug: "super-resolution",
description:
"Upscaling models that create high-quality images from low-quality images.",
},
],
next: `${BASE_URL}/collections?page=2`,
previous: null,
});
nock(BASE_URL)
.get("/collections?page=2")
.reply(200, {
results: [
{
name: "Image classification",
slug: "image-classification",
description: "Models that classify images.",
},
],
next: null,
previous: null,
});

const controller = new AbortController();
const iterator = client.paginate(client.collections.list, {
signal: controller.signal,
});

const firstIteration = await iterator.next();
expect(firstIteration.value.length).toBe(1);

controller.abort();

const secondIteration = await iterator.next();
expect(secondIteration.value).toBeUndefined();
expect(secondIteration.done).toBe(true);
});
});

describe("account.get", () => {
test("Calls the correct API route", async () => {
nock(BASE_URL).get("/account").reply(200, {
Expand Down
8 changes: 4 additions & 4 deletions integration/next/pages/index.js
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export default () => (
<main>
<h1>Welcome to Next.js</h1>
</main>
)
<main>
<h1>Welcome to Next.js</h1>
</main>
);
5 changes: 4 additions & 1 deletion lib/accounts.js
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
/**
* Get the current account
*
* @param {object} [options]
* @param {AbortSignal} [options.signal] - An optional AbortSignal
* @returns {Promise<object>} Resolves with the current account
*/
async function getCurrentAccount() {
async function getCurrentAccount({ signal } = {}) {
const response = await this.request("/account", {
method: "GET",
signal,
});

return response.json();
Expand Down
Loading