Skip to content

Commit 8fd6c08

Browse files
authored
Allow replicate.predictions.create to accept version or model options (#170)
* Allow replicate.predictions.create to accept version or model options * Remove models.predictions.create method
1 parent cbde2c1 commit 8fd6c08

File tree

5 files changed

+49
-77
lines changed

5 files changed

+49
-77
lines changed

index.d.ts

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -188,28 +188,19 @@ declare module "replicate" {
188188
version_id: string
189189
): Promise<ModelVersion>;
190190
};
191-
predictions: {
192-
create(
193-
model_owner: string,
194-
model_name: string,
195-
options: {
196-
input: object;
197-
stream?: boolean;
198-
webhook?: string;
199-
webhook_events_filter?: WebhookEventType[];
200-
}
201-
): Promise<Prediction>;
202-
};
203191
};
204192

205193
predictions: {
206-
create(options: {
207-
version: string;
208-
input: object;
209-
stream?: boolean;
210-
webhook?: string;
211-
webhook_events_filter?: WebhookEventType[];
212-
}): Promise<Prediction>;
194+
create(
195+
options: {
196+
model?: string;
197+
version?: string;
198+
input: object;
199+
stream?: boolean;
200+
webhook?: string;
201+
webhook_events_filter?: WebhookEventType[];
202+
} & ({ version: string } | { model: string })
203+
): Promise<Prediction>;
213204
get(prediction_id: string): Promise<Prediction>;
214205
cancel(prediction_id: string): Promise<Prediction>;
215206
list(): Promise<Page<Prediction>>;

index.js

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,9 +70,6 @@ class Replicate {
7070
list: models.versions.list.bind(this),
7171
get: models.versions.get.bind(this),
7272
},
73-
predictions: {
74-
create: models.predictions.create.bind(this),
75-
},
7673
};
7774

7875
this.predictions = {
@@ -117,12 +114,13 @@ class Replicate {
117114
...data,
118115
version: identifier.version,
119116
});
117+
} else if (identifier.owner && identifier.name) {
118+
prediction = await this.predictions.create({
119+
...data,
120+
model: `${identifier.owner}/${identifier.name}`,
121+
});
120122
} else {
121-
prediction = await this.models.predictions.create(
122-
identifier.owner,
123-
identifier.name,
124-
data
125-
);
123+
throw new Error("Invalid model version identifier");
126124
}
127125

128126
// Call progress callback with the initial prediction object
@@ -260,12 +258,14 @@ class Replicate {
260258
version: identifier.version,
261259
stream: true,
262260
});
261+
} else if (identifier.owner && identifier.name) {
262+
prediction = await this.predictions.create({
263+
...data,
264+
model: `${identifier.owner}/${identifier.name}`,
265+
stream: true,
266+
});
263267
} else {
264-
prediction = await this.models.predictions.create(
265-
identifier.owner,
266-
identifier.name,
267-
{ ...data, stream: true }
268-
);
268+
throw new Error("Invalid model version identifier");
269269
}
270270

271271
if (prediction.urls && prediction.urls.stream) {

index.test.ts

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,7 @@ describe("Replicate client", () => {
700700
// Add more tests for error handling, edge cases, etc.
701701
});
702702

703-
describe("models.predictions.create", () => {
703+
describe("predictions.create with model", () => {
704704
test("Calls the correct API route with the correct payload", async () => {
705705
nock(BASE_URL)
706706
.post("/models/meta/llama-2-70b-chat/predictions")
@@ -721,17 +721,14 @@ describe("Replicate client", () => {
721721
get: "https://api.replicate.com/v1/predictions/heat2o3bzn3ahtr6bjfftvbaci",
722722
},
723723
});
724-
const prediction = await client.models.predictions.create(
725-
"meta",
726-
"llama-2-70b-chat",
727-
{
728-
input: {
729-
prompt: "Please write a haiku about llamas.",
730-
},
731-
webhook: "http://test.host/webhook",
732-
webhook_events_filter: ["output", "completed"],
733-
}
734-
);
724+
const prediction = await client.predictions.create({
725+
model: "meta/llama-2-70b-chat",
726+
input: {
727+
prompt: "Please write a haiku about llamas.",
728+
},
729+
webhook: "http://test.host/webhook",
730+
webhook_events_filter: ["output", "completed"],
731+
});
735732
expect(prediction.id).toBe("heat2o3bzn3ahtr6bjfftvbaci");
736733
});
737734
// Add more tests for error handling, edge cases, etc.

lib/models.js

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -89,36 +89,9 @@ async function createModel(model_owner, model_name, options) {
8989
return response.json();
9090
}
9191

92-
/**
93-
* Create a new prediction
94-
*
95-
* @param {string} model_owner - Required. The name of the user or organization that owns the model
96-
* @param {string} model_name - Required. The name of the model
97-
* @param {object} options
98-
* @param {object} options.input - Required. An object with the model inputs
99-
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
100-
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
101-
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
102-
* @returns {Promise<object>} Resolves with the created prediction
103-
*/
104-
async function createPrediction(model_owner, model_name, options) {
105-
const { stream, ...data } = options;
106-
107-
const response = await this.request(
108-
`/models/${model_owner}/${model_name}/predictions`,
109-
{
110-
method: "POST",
111-
data: { ...data, stream },
112-
}
113-
);
114-
115-
return response.json();
116-
}
117-
11892
module.exports = {
11993
get: getModel,
12094
list: listModels,
12195
create: createModel,
12296
versions: { list: listModelVersions, get: getModelVersion },
123-
predictions: { create: createPrediction },
12497
};

lib/predictions.js

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
* Create a new prediction
33
*
44
* @param {object} options
5-
* @param {string} options.version - Required. The model version
5+
* @param {string} options.model - The model.
6+
* @param {string} options.version - The model version.
67
* @param {object} options.input - Required. An object with the model inputs
78
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
89
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
910
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
1011
* @returns {Promise<object>} Resolves with the created prediction
1112
*/
1213
async function createPrediction(options) {
13-
const { stream, ...data } = options;
14+
const { model, version, stream, ...data } = options;
1415

1516
if (data.webhook) {
1617
try {
@@ -21,10 +22,20 @@ async function createPrediction(options) {
2122
}
2223
}
2324

24-
const response = await this.request("/predictions", {
25-
method: "POST",
26-
data: { ...data, stream },
27-
});
25+
let response;
26+
if (version) {
27+
response = await this.request("/predictions", {
28+
method: "POST",
29+
data: { ...data, stream, version },
30+
});
31+
} else if (model) {
32+
response = await this.request(`/models/${model}/predictions`, {
33+
method: "POST",
34+
data: { ...data, stream },
35+
});
36+
} else {
37+
throw new Error("Either model or version must be specified");
38+
}
2839

2940
return response.json();
3041
}

0 commit comments

Comments
 (0)