diff --git a/index.d.ts b/index.d.ts index ab40ed8..69e651b 100644 --- a/index.d.ts +++ b/index.d.ts @@ -22,6 +22,23 @@ declare module "replicate" { models?: Model[]; } + export interface Deployment { + owner: string; + name: string; + current_release: { + number: number; + model: string; + version: string; + created_at: string; + created_by: Account; + configuration: { + hardware: string; + min_instances: number; + max_instances: number; + }; + }; + } + export interface Hardware { sku: string; name: string; @@ -173,6 +190,10 @@ declare module "replicate" { } ): Promise; }; + get( + deployment_owner: string, + deployment_name: string + ): Promise; }; hardware: { diff --git a/index.js b/index.js index 13207ca..83b9888 100644 --- a/index.js +++ b/index.js @@ -59,6 +59,7 @@ class Replicate { }; this.deployments = { + get: deployments.get.bind(this), predictions: { create: deployments.predictions.create.bind(this), }, diff --git a/index.test.ts b/index.test.ts index 106cc58..ae01338 100644 --- a/index.test.ts +++ b/index.test.ts @@ -72,7 +72,7 @@ describe("Replicate client", () => { }); }); - describe("accounts.current", () => { + describe("account.get", () => { test("Calls the correct API route", async () => { nock(BASE_URL).get("/account").reply(200, { type: "organization", @@ -721,6 +721,47 @@ describe("Replicate client", () => { // Add more tests for error handling, edge cases, etc. }); + describe("deployments.get", () => { + test("Calls the correct API route with the correct payload", async () => { + nock(BASE_URL) + .get("/deployments/acme/my-app-image-generator") + .reply(200, { + owner: "acme", + name: "my-app-image-generator", + current_release: { + number: 1, + model: "stability-ai/sdxl", + version: + "da77bc59ee60423279fd632efb4795ab731d9e3ca9705ef3341091fb989b7eaf", + created_at: "2024-02-15T16:32:57.018467Z", + created_by: { + type: "organization", + username: "acme", + name: "Acme Corp, Inc.", + github_url: "https://github.com/acme", + }, + configuration: { + hardware: "gpu-t4", + scaling: { + min_instances: 1, + max_instances: 5, + }, + }, + }, + }); + + const deployment = await client.deployments.get( + "acme", + "my-app-image-generator" + ); + + expect(deployment.owner).toBe("acme"); + expect(deployment.name).toBe("my-app-image-generator"); + expect(deployment.current_release.model).toBe("stability-ai/sdxl"); + }); + // Add more tests for error handling, edge cases, etc. + }); + describe("predictions.create with model", () => { test("Calls the correct API route with the correct payload", async () => { nock(BASE_URL) diff --git a/lib/deployments.js b/lib/deployments.js index 6f32cdb..8ba5ea3 100644 --- a/lib/deployments.js +++ b/lib/deployments.js @@ -33,8 +33,27 @@ async function createPrediction(deployment_owner, deployment_name, options) { return response.json(); } +/** + * Get a deployment + * + * @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment + * @param {string} deployment_name - Required. The name of the deployment + * @returns {Promise} Resolves with the deployment data + */ +async function getDeployment(deployment_owner, deployment_name) { + const response = await this.request( + `/deployments/${deployment_owner}/${deployment_name}`, + { + method: "GET", + } + ); + + return response.json(); +} + module.exports = { predictions: { create: createPrediction, }, + get: getDeployment, };