Skip to content

Commit 47f8b01

Browse files
authored
Add support for models.create and hardware.list endpoints (#153)
* Add support for hardware.list endpoint * Add support for models.create endpoint * Update README * Update test expectation
1 parent b14eeaa commit 47f8b01

File tree

6 files changed

+180
-24
lines changed

6 files changed

+180
-24
lines changed

README.md

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,43 @@ const response = await replicate.models.list();
236236
}
237237
```
238238

239+
### `replicate.models.create`
240+
241+
Create a new public or private model.
242+
243+
```js
244+
const response = await replicate.models.create(model_owner, model_name, options);
245+
```
246+
247+
| name | type | description |
248+
| ------------------------- | ------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
249+
| `model_owner` | string | **Required**. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization. |
250+
| `model_name` | string | **Required**. The name of the model. This must be unique among all models owned by the user or organization. |
251+
| `options.visibility` | string | **Required**. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model. |
252+
| `options.hardware` | string | **Required**. The SKU for the hardware used to run the model. Possible values can be found by calling [`replicate.hardware.list()](#replicatehardwarelist)`. |
253+
| `options.description` | string | A description of the model. |
254+
| `options.github_url` | string | A URL for the model's source code on GitHub. |
255+
| `options.paper_url` | string | A URL for the model's paper. |
256+
| `options.license_url` | string | A URL for the model's license. |
257+
| `options.cover_image_url` | string | A URL for the model's cover image. This should be an image file. |
258+
259+
### `replicate.hardware.list`
260+
261+
List available hardware for running models on Replicate.
262+
263+
```js
264+
const response = await replicate.hardware.list()
265+
```
266+
267+
```jsonc
268+
[
269+
{"name": "CPU", "sku": "cpu" },
270+
{"name": "Nvidia T4 GPU", "sku": "gpu-t4" },
271+
{"name": "Nvidia A40 GPU", "sku": "gpu-a40-small" },
272+
{"name": "Nvidia A40 (Large) GPU", "sku": "gpu-a40-large" },
273+
]
274+
```
275+
239276
### `replicate.models.versions.list`
240277

241278
Get a list of all published versions of a model, including input and output schemas for each version.

index.d.ts

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
declare module 'replicate' {
22
type Status = 'starting' | 'processing' | 'succeeded' | 'failed' | 'canceled';
3+
type Visibility = 'public' | 'private';
34
type WebhookEventType = 'start' | 'output' | 'logs' | 'completed';
45

56
export interface ApiError extends Error {
@@ -14,6 +15,11 @@ declare module 'replicate' {
1415
models?: Model[];
1516
}
1617

18+
export interface Hardware {
19+
sku: string;
20+
name: string
21+
}
22+
1723
export interface Model {
1824
url: string;
1925
owner: string;
@@ -115,9 +121,40 @@ declare module 'replicate' {
115121
get(collection_slug: string): Promise<Collection>;
116122
};
117123

124+
deployments: {
125+
predictions: {
126+
create(
127+
deployment_owner: string,
128+
deployment_name: string,
129+
options: {
130+
input: object;
131+
stream?: boolean;
132+
webhook?: string;
133+
webhook_events_filter?: WebhookEventType[];
134+
}
135+
): Promise<Prediction>;
136+
};
137+
};
138+
139+
hardware: {
140+
list(): Promise<Hardware[]>
141+
}
142+
118143
models: {
119144
get(model_owner: string, model_name: string): Promise<Model>;
120145
list(): Promise<Page<Model>>;
146+
create(
147+
model_owner: string,
148+
model_name: string,
149+
options: {
150+
visibility: Visibility;
151+
hardware: string;
152+
description?: string;
153+
github_url?: string;
154+
paper_url?: string;
155+
license_url?: string;
156+
cover_image_url?: string;
157+
}): Promise<Model>;
121158
versions: {
122159
list(model_owner: string, model_name: string): Promise<ModelVersion[]>;
123160
get(
@@ -157,20 +194,5 @@ declare module 'replicate' {
157194
cancel(training_id: string): Promise<Training>;
158195
list(): Promise<Page<Training>>;
159196
};
160-
161-
deployments: {
162-
predictions: {
163-
create(
164-
deployment_owner: string,
165-
deployment_name: string,
166-
options: {
167-
input: object;
168-
stream?: boolean;
169-
webhook?: string;
170-
webhook_events_filter?: WebhookEventType[];
171-
}
172-
): Promise<Prediction>;
173-
};
174-
};
175197
}
176198
}

index.js

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ const { withAutomaticRetries } = require('./lib/util');
33

44
const collections = require('./lib/collections');
55
const deployments = require('./lib/deployments');
6+
const hardware = require('./lib/hardware');
67
const models = require('./lib/models');
78
const predictions = require('./lib/predictions');
89
const trainings = require('./lib/trainings');
@@ -49,9 +50,20 @@ class Replicate {
4950
get: collections.get.bind(this),
5051
};
5152

53+
this.deployments = {
54+
predictions: {
55+
create: deployments.predictions.create.bind(this),
56+
}
57+
};
58+
59+
this.hardware = {
60+
list: hardware.list.bind(this),
61+
};
62+
5263
this.models = {
5364
get: models.get.bind(this),
5465
list: models.list.bind(this),
66+
create: models.create.bind(this),
5567
versions: {
5668
list: models.versions.list.bind(this),
5769
get: models.versions.get.bind(this),
@@ -71,12 +83,6 @@ class Replicate {
7183
cancel: trainings.cancel.bind(this),
7284
list: trainings.list.bind(this),
7385
};
74-
75-
this.deployments = {
76-
predictions: {
77-
create: deployments.predictions.create.bind(this),
78-
}
79-
};
8086
}
8187

8288
/**

index.test.ts

Lines changed: 51 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,20 +136,20 @@ describe('Replicate client', () => {
136136
nock(BASE_URL)
137137
.get('/models')
138138
.reply(200, {
139-
results: [{ url: 'https://replicate.com/some-user/model-1' }],
139+
results: [ { url: 'https://replicate.com/some-user/model-1' } ],
140140
next: 'https://api.replicate.com/v1/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw',
141141
})
142142
.get('/models?cursor=cD0yMDIyLTAxLTIxKzIzJTNBMTglM0EyNC41MzAzNTclMkIwMCUzQTAw')
143143
.reply(200, {
144-
results: [{ url: 'https://replicate.com/some-user/model-2' }],
144+
results: [ { url: 'https://replicate.com/some-user/model-2' } ],
145145
next: null,
146146
});
147147

148148
const results: Model[] = [];
149149
for await (const batch of client.paginate(client.models.list)) {
150150
results.push(...batch);
151151
}
152-
expect(results).toEqual([{ url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' }]);
152+
expect(results).toEqual([ { url: 'https://replicate.com/some-user/model-1' }, { url: 'https://replicate.com/some-user/model-2' } ]);
153153

154154
// Add more tests for error handling, edge cases, etc.
155155
});
@@ -662,6 +662,54 @@ describe('Replicate client', () => {
662662
// Add more tests for error handling, edge cases, etc.
663663
});
664664

665+
describe('hardware.list', () => {
666+
test('Calls the correct API route', async () => {
667+
nock(BASE_URL)
668+
.get('/hardware')
669+
.reply(200, [
670+
{ name: "CPU", sku: "cpu" },
671+
{ name: "Nvidia T4 GPU", sku: "gpu-t4" },
672+
{ name: "Nvidia A40 GPU", sku: "gpu-a40-small" },
673+
{ name: "Nvidia A40 (Large) GPU", sku: "gpu-a40-large" },
674+
]);
675+
676+
const hardware = await client.hardware.list();
677+
expect(hardware.length).toBe(4);
678+
expect(hardware[ 0 ].name).toBe('CPU');
679+
expect(hardware[ 0 ].sku).toBe('cpu');
680+
});
681+
// Add more tests for error handling, edge cases, etc.
682+
});
683+
684+
describe('models.create', () => {
685+
test('Calls the correct API route with the correct payload', async () => {
686+
nock(BASE_URL)
687+
.post('/models')
688+
.reply(200, {
689+
owner: 'test-owner',
690+
name: 'test-model',
691+
visibility: 'public',
692+
hardware: 'cpu',
693+
description: 'A test model',
694+
});
695+
696+
const model = await client.models.create(
697+
'test-owner',
698+
'test-model',
699+
{
700+
visibility: 'public',
701+
hardware: 'cpu',
702+
description: 'A test model',
703+
});
704+
705+
expect(model.owner).toBe('test-owner');
706+
expect(model.name).toBe('test-model');
707+
expect(model.visibility).toBe('public');
708+
// expect(model.hardware).toBe('cpu');
709+
expect(model.description).toBe('A test model');
710+
});
711+
});
712+
665713
describe('run', () => {
666714
test('Calls the correct API routes', async () => {
667715
let firstPollingRequest = true;

lib/hardware.js

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
/**
2+
* List hardware
3+
*
4+
* @returns {Promise<object[]>} Resolves with the array of hardware
5+
*/
6+
async function listHardware() {
7+
const response = await this.request('/hardware', {
8+
method: 'GET',
9+
});
10+
11+
return response.json();
12+
}
13+
14+
module.exports = {
15+
list: listHardware,
16+
};

lib/models.js

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,35 @@ async function listModels() {
5757
return response.json();
5858
}
5959

60+
/**
61+
* Create a new model
62+
*
63+
* @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
64+
* @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization.
65+
* @param {object} options
66+
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
67+
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
68+
* @param {string} options.description - A description of the model.
69+
* @param {string} options.github_url - A URL for the model's source code on GitHub.
70+
* @param {string} options.paper_url - A URL for the model's paper.
71+
* @param {string} options.license_url - A URL for the model's license.
72+
* @param {string} options.cover_image_url - A URL for the model's cover image. This should be an image file.
73+
* @returns {Promise<object>} Resolves with the model version data
74+
*/
75+
async function createModel(model_owner, model_name, options) {
76+
const data = { owner: model_owner, name: model_name, ...options };
77+
78+
const response = await this.request('/models', {
79+
method: 'POST',
80+
data,
81+
});
82+
83+
return response.json();
84+
}
85+
6086
module.exports = {
6187
get: getModel,
6288
list: listModels,
89+
create: createModel,
6390
versions: { list: listModelVersions, get: getModelVersion },
6491
};

0 commit comments

Comments
 (0)