Skip to content

Commit 3e756e8

Browse files
committed
Use "version" not "version" as the noun
We specify a model by the pattern `owner/model@version`. Conceptually, the model is the thing, not the version. The version is a way of specifying which version of the model we want, and doesn't make much sense in isolation (even if the API itself doesn't ask for the model (yet)).
1 parent ed327ef commit 3e756e8

File tree

6 files changed

+126
-41
lines changed

6 files changed

+126
-41
lines changed

README.md

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ To run a prediction and return its output:
2222
import replicate from "replicate";
2323

2424
const prediction = await replicate
25-
.version("db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
25+
.model(
26+
"stability-ai/stable-diffusion@db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
27+
)
2628
.predict({
2729
prompt: "painting of a cat by andy warhol",
2830
});
@@ -38,7 +40,9 @@ running, you can pass in an `onUpdate` callback function:
3840
import replicate from "replicate";
3941

4042
await replicate
41-
.version("db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
43+
.model(
44+
"stability-ai/stable-diffusion@db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
45+
)
4246
.predict(
4347
{
4448
prompt: "painting of a cat by andy warhol",
@@ -58,7 +62,9 @@ If you'd prefer to control your own polling you can use the low-level
5862
import replicate from "replicate";
5963

6064
const prediction = await replicate
61-
.version("db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf")
65+
.model(
66+
"stability-ai/stable-diffusion@db21e45d3f7023abc2a46ee38a23973f6dce16bb082a930b0c49861f96d1e5bf"
67+
)
6268
.createPrediction({
6369
prompt: "painting of a cat by andy warhol",
6470
});

lib/Version.js renamed to lib/Model.js

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,32 @@ import Prediction from "./Prediction.js";
33
import ReplicateObject from "./ReplicateObject.js";
44
import { sleep } from "./utils.js";
55

6-
export default class Version extends ReplicateObject {
7-
constructor({ id, ...rest }, client) {
6+
export default class Model extends ReplicateObject {
7+
constructor({ owner, name, version, id, ...rest }, client) {
88
super(rest, client);
99

10-
if (!id) {
11-
throw new ReplicateError("id is required");
10+
if (!owner) {
11+
throw new ReplicateError("owner is required");
1212
}
1313

14-
this.id = id;
14+
this.owner = owner;
15+
16+
if (!name) {
17+
throw new ReplicateError("name is required");
18+
}
19+
20+
this.name = name;
21+
22+
if (!(version || id)) {
23+
throw new ReplicateError("version is required");
24+
}
25+
26+
// We get `id` back from the API, but we want to call it `version` here.
27+
this.version = version || id;
28+
}
29+
30+
actionForGet() {
31+
return `GET /v1/models/${this.owner}/${this.name}/versions/${this.version}`;
1532
}
1633

1734
async predict(
@@ -68,11 +85,12 @@ export default class Version extends ReplicateObject {
6885

6986
async createPrediction(input) {
7087
// This is here and not on `Prediction` because conceptually, a prediction
71-
// from a version "belongs" to the version. It's an odd feature of the API
72-
// that the prediction creation isn't an action on the version, but we don't
73-
// need to expose that to users of this library.
88+
// from a model "belongs" to the model. It's an odd feature of the API that
89+
// the prediction creation isn't an action on the model (or that it doesn't
90+
// actually use the model information, only the version), but we don't need
91+
// to expose that to users of this library.
7492
const predictionData = await this.client.request("POST /v1/predictions", {
75-
version: this.id,
93+
version: this.version,
7694
input,
7795
});
7896

lib/Version.test.js renamed to lib/Model.test.js

Lines changed: 60 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,48 +7,83 @@ jest.unstable_mockModule("node-fetch", () => ({
77
}));
88

99
import { ReplicateResponseError } from "./errors.js";
10+
import Model from "./Model.js";
1011
import Prediction, { PredictionStatus } from "./Prediction.js";
1112

1213
const { default: ReplicateClient } = await import("./ReplicateClient.js");
1314

1415
let client;
15-
let version;
16+
let model;
1617

1718
beforeEach(() => {
1819
process.env.REPLICATE_API_TOKEN = "test-token-from-env";
1920

2021
client = new ReplicateClient({});
21-
version = client.version("test-version");
22+
model = client.model("test-owner/test-name@testversion");
23+
});
24+
25+
describe("load()", () => {
26+
it("makes request to get model version", async () => {
27+
jest.spyOn(client, "request").mockResolvedValue({
28+
id: "testversion",
29+
});
30+
31+
await model.load();
32+
33+
expect(client.request).toHaveBeenCalledWith(
34+
"GET /v1/models/test-owner/test-name/versions/testversion"
35+
);
36+
});
37+
38+
it("returns Model", async () => {
39+
jest.spyOn(client, "request").mockResolvedValue({
40+
id: "testversion",
41+
});
42+
43+
const returnedModel = await model.load();
44+
45+
expect(returnedModel).toBeInstanceOf(Model);
46+
});
47+
48+
it("updates Model in place", async () => {
49+
jest.spyOn(client, "request").mockResolvedValue({
50+
id: "testversion",
51+
});
52+
53+
const returnedModel = await model.load();
54+
55+
expect(returnedModel).toBe(model);
56+
});
2257
});
2358

2459
describe("predict()", () => {
2560
it("makes request to create prediction", async () => {
26-
jest.spyOn(version, "createPrediction").mockResolvedValue(
61+
jest.spyOn(model, "createPrediction").mockResolvedValue(
2762
new Prediction(
2863
{
29-
id: "test-prediction",
64+
id: "testprediction",
3065
status: PredictionStatus.SUCCEEDED,
3166
},
3267
client
3368
)
3469
);
3570

36-
await version.predict(
71+
await model.predict(
3772
{ text: "test text" },
3873
{},
3974
{ defaultPollingInterval: 0 }
4075
);
4176

42-
expect(version.createPrediction).toHaveBeenCalledWith({
77+
expect(model.createPrediction).toHaveBeenCalledWith({
4378
text: "test text",
4479
});
4580
});
4681

4782
it("uses created prediction's ID to fetch update", async () => {
48-
jest.spyOn(version, "createPrediction").mockResolvedValue(
83+
jest.spyOn(model, "createPrediction").mockResolvedValue(
4984
new Prediction(
5085
{
51-
id: "test-prediction",
86+
id: "testprediction",
5287
status: PredictionStatus.STARTING,
5388
},
5489
client
@@ -75,20 +110,20 @@ describe("predict()", () => {
75110
.spyOn(client, "request")
76111
.mockImplementation((action) => requestMockReturnValues[action]);
77112

78-
await version.predict(
113+
await model.predict(
79114
{ text: "test text" },
80115
{},
81116
{ defaultPollingInterval: 0 }
82117
);
83118

84-
expect(client.prediction).toHaveBeenCalledWith("test-prediction");
119+
expect(client.prediction).toHaveBeenCalledWith("testprediction");
85120
});
86121

87122
it("polls prediction status until success", async () => {
88-
jest.spyOn(version, "createPrediction").mockResolvedValue(
123+
jest.spyOn(model, "createPrediction").mockResolvedValue(
89124
new Prediction(
90125
{
91-
id: "test-prediction",
126+
id: "testprediction",
92127
status: PredictionStatus.STARTING,
93128
},
94129
client
@@ -98,21 +133,21 @@ describe("predict()", () => {
98133
const predictionLoadResults = [
99134
new Prediction(
100135
{
101-
id: "test-prediction",
136+
id: "testprediction",
102137
status: PredictionStatus.PROCESSING,
103138
},
104139
client
105140
),
106141
new Prediction(
107142
{
108-
id: "test-prediction",
143+
id: "testprediction",
109144
status: PredictionStatus.PROCESSING,
110145
},
111146
client
112147
),
113148
new Prediction(
114149
{
115-
id: "test-prediction",
150+
id: "testprediction",
116151
status: PredictionStatus.SUCCEEDED,
117152
},
118153
client
@@ -122,14 +157,14 @@ describe("predict()", () => {
122157
const predictionLoad = jest.fn(() => predictionLoadResults.shift());
123158

124159
jest.spyOn(client, "prediction").mockImplementation(() => {
125-
const prediction = new Prediction({ id: "test-prediction" }, client);
160+
const prediction = new Prediction({ id: "testprediction" }, client);
126161

127162
jest.spyOn(prediction, "load").mockImplementation(predictionLoad);
128163

129164
return prediction;
130165
});
131166

132-
const prediction = await version.predict(
167+
const prediction = await model.predict(
133168
{ text: "test text" },
134169
{},
135170
{ defaultPollingInterval: 0 }
@@ -140,10 +175,10 @@ describe("predict()", () => {
140175
});
141176

142177
it("retries polling on error", async () => {
143-
jest.spyOn(version, "createPrediction").mockResolvedValue(
178+
jest.spyOn(model, "createPrediction").mockResolvedValue(
144179
new Prediction(
145180
{
146-
id: "test-prediction",
181+
id: "testprediction",
147182
status: PredictionStatus.STARTING,
148183
},
149184
client
@@ -172,7 +207,7 @@ describe("predict()", () => {
172207
() =>
173208
new Prediction(
174209
{
175-
id: "test-prediction",
210+
id: "testprediction",
176211
status: PredictionStatus.SUCCEEDED,
177212
},
178213
client
@@ -182,15 +217,15 @@ describe("predict()", () => {
182217
const predictionLoad = jest.fn(() => predictionLoadResults.shift()());
183218

184219
jest.spyOn(client, "prediction").mockImplementation(() => {
185-
const prediction = new Prediction({ id: "test-prediction" }, client);
220+
const prediction = new Prediction({ id: "testprediction" }, client);
186221

187222
jest.spyOn(prediction, "load").mockImplementation(predictionLoad);
188223

189224
return prediction;
190225
});
191226
const backoffFn = jest.fn(() => 0);
192227

193-
const prediction = await version.predict(
228+
const prediction = await model.predict(
194229
{ text: "test text" },
195230
{},
196231
{ defaultPollingInterval: 0, backoffFn }
@@ -205,14 +240,14 @@ describe("predict()", () => {
205240
describe("createPrediction()", () => {
206241
it("makes request to create prediction", async () => {
207242
jest.spyOn(client, "request").mockResolvedValue({
208-
id: "test-prediction",
243+
id: "testprediction",
209244
status: PredictionStatus.SUCCEEDED,
210245
});
211246

212-
await version.createPrediction({ text: "test text" });
247+
await model.createPrediction({ text: "test text" });
213248

214249
expect(client.request).toHaveBeenCalledWith("POST /v1/predictions", {
215-
version: "test-version",
250+
version: "testversion",
216251
input: { text: "test text" },
217252
});
218253
});

lib/Prediction.test.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ describe("load()", () => {
4444
expect(returnedPrediction).toBeInstanceOf(Prediction);
4545
});
4646

47-
it("updates the prediction in place", async () => {
47+
it("updates Prediction in place", async () => {
4848
jest.spyOn(client, "request").mockResolvedValue({
4949
id: "testprediction",
5050
status: PredictionStatus.SUCCEEDED,

lib/ReplicateClient.js

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@ import {
55
ReplicateRequestError,
66
ReplicateResponseError,
77
} from "./errors.js";
8+
import Model from "./Model.js";
89
import Prediction from "./Prediction.js";
9-
import Version from "./Version.js";
1010

1111
export default class ReplicateClient {
1212
baseURL;
@@ -21,8 +21,16 @@ export default class ReplicateClient {
2121
}
2222
}
2323

24-
version(id) {
25-
return new Version({ id }, this);
24+
model(fullId) {
25+
const idComponents = /^([\w-]+)\/([\w-]+)@(\w+)$/.exec(fullId);
26+
27+
if (!idComponents) {
28+
throw new ReplicateError(`Invalid ID: ${fullId}`);
29+
}
30+
31+
const [, owner, name, version] = idComponents;
32+
33+
return new Model({ owner, name, version }, this);
2634
}
2735

2836
prediction(id) {

lib/ReplicateClient.test.js

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import {
1111
ReplicateRequestError,
1212
ReplicateResponseError,
1313
} from "./errors.js";
14+
import Model from "./Model.js";
15+
import Prediction from "./Prediction.js";
1416

1517
const { default: fetch } = await import("node-fetch");
1618
const { default: ReplicateClient } = await import("./ReplicateClient.js");
@@ -63,6 +65,22 @@ describe("constructor()", () => {
6365
});
6466
});
6567

68+
describe("model()", () => {
69+
it("returns Model", () => {
70+
const model = client.model("test-owner/test-name@testversion");
71+
72+
expect(model).toBeInstanceOf(Model);
73+
});
74+
});
75+
76+
describe("prediction()", () => {
77+
it("returns Prediction", () => {
78+
const prediction = client.prediction("testprediction");
79+
80+
expect(prediction).toBeInstanceOf(Prediction);
81+
});
82+
});
83+
6684
describe("request()", () => {
6785
it("throws ReplicateRequestError on failed fetch", async () => {
6886
fetch.mockImplementation(async () => {

0 commit comments

Comments
 (0)