@@ -7,48 +7,83 @@ jest.unstable_mockModule("node-fetch", () => ({
7
7
} ) ) ;
8
8
9
9
import { ReplicateResponseError } from "./errors.js" ;
10
+ import Model from "./Model.js" ;
10
11
import Prediction , { PredictionStatus } from "./Prediction.js" ;
11
12
12
13
const { default : ReplicateClient } = await import ( "./ReplicateClient.js" ) ;
13
14
14
15
let client ;
15
- let version ;
16
+ let model ;
16
17
17
18
beforeEach ( ( ) => {
18
19
process . env . REPLICATE_API_TOKEN = "test-token-from-env" ;
19
20
20
21
client = new ReplicateClient ( { } ) ;
21
- version = client . version ( "test-version" ) ;
22
+ model = client . model ( "test-owner/test-name@testversion" ) ;
23
+ } ) ;
24
+
25
+ describe ( "load()" , ( ) => {
26
+ it ( "makes request to get model version" , async ( ) => {
27
+ jest . spyOn ( client , "request" ) . mockResolvedValue ( {
28
+ id : "testversion" ,
29
+ } ) ;
30
+
31
+ await model . load ( ) ;
32
+
33
+ expect ( client . request ) . toHaveBeenCalledWith (
34
+ "GET /v1/models/test-owner/test-name/versions/testversion"
35
+ ) ;
36
+ } ) ;
37
+
38
+ it ( "returns Model" , async ( ) => {
39
+ jest . spyOn ( client , "request" ) . mockResolvedValue ( {
40
+ id : "testversion" ,
41
+ } ) ;
42
+
43
+ const returnedModel = await model . load ( ) ;
44
+
45
+ expect ( returnedModel ) . toBeInstanceOf ( Model ) ;
46
+ } ) ;
47
+
48
+ it ( "updates Model in place" , async ( ) => {
49
+ jest . spyOn ( client , "request" ) . mockResolvedValue ( {
50
+ id : "testversion" ,
51
+ } ) ;
52
+
53
+ const returnedModel = await model . load ( ) ;
54
+
55
+ expect ( returnedModel ) . toBe ( model ) ;
56
+ } ) ;
22
57
} ) ;
23
58
24
59
describe ( "predict()" , ( ) => {
25
60
it ( "makes request to create prediction" , async ( ) => {
26
- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
61
+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
27
62
new Prediction (
28
63
{
29
- id : "test-prediction " ,
64
+ id : "testprediction " ,
30
65
status : PredictionStatus . SUCCEEDED ,
31
66
} ,
32
67
client
33
68
)
34
69
) ;
35
70
36
- await version . predict (
71
+ await model . predict (
37
72
{ text : "test text" } ,
38
73
{ } ,
39
74
{ defaultPollingInterval : 0 }
40
75
) ;
41
76
42
- expect ( version . createPrediction ) . toHaveBeenCalledWith ( {
77
+ expect ( model . createPrediction ) . toHaveBeenCalledWith ( {
43
78
text : "test text" ,
44
79
} ) ;
45
80
} ) ;
46
81
47
82
it ( "uses created prediction's ID to fetch update" , async ( ) => {
48
- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
83
+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
49
84
new Prediction (
50
85
{
51
- id : "test-prediction " ,
86
+ id : "testprediction " ,
52
87
status : PredictionStatus . STARTING ,
53
88
} ,
54
89
client
@@ -75,20 +110,20 @@ describe("predict()", () => {
75
110
. spyOn ( client , "request" )
76
111
. mockImplementation ( ( action ) => requestMockReturnValues [ action ] ) ;
77
112
78
- await version . predict (
113
+ await model . predict (
79
114
{ text : "test text" } ,
80
115
{ } ,
81
116
{ defaultPollingInterval : 0 }
82
117
) ;
83
118
84
- expect ( client . prediction ) . toHaveBeenCalledWith ( "test-prediction " ) ;
119
+ expect ( client . prediction ) . toHaveBeenCalledWith ( "testprediction " ) ;
85
120
} ) ;
86
121
87
122
it ( "polls prediction status until success" , async ( ) => {
88
- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
123
+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
89
124
new Prediction (
90
125
{
91
- id : "test-prediction " ,
126
+ id : "testprediction " ,
92
127
status : PredictionStatus . STARTING ,
93
128
} ,
94
129
client
@@ -98,21 +133,21 @@ describe("predict()", () => {
98
133
const predictionLoadResults = [
99
134
new Prediction (
100
135
{
101
- id : "test-prediction " ,
136
+ id : "testprediction " ,
102
137
status : PredictionStatus . PROCESSING ,
103
138
} ,
104
139
client
105
140
) ,
106
141
new Prediction (
107
142
{
108
- id : "test-prediction " ,
143
+ id : "testprediction " ,
109
144
status : PredictionStatus . PROCESSING ,
110
145
} ,
111
146
client
112
147
) ,
113
148
new Prediction (
114
149
{
115
- id : "test-prediction " ,
150
+ id : "testprediction " ,
116
151
status : PredictionStatus . SUCCEEDED ,
117
152
} ,
118
153
client
@@ -122,14 +157,14 @@ describe("predict()", () => {
122
157
const predictionLoad = jest . fn ( ( ) => predictionLoadResults . shift ( ) ) ;
123
158
124
159
jest . spyOn ( client , "prediction" ) . mockImplementation ( ( ) => {
125
- const prediction = new Prediction ( { id : "test-prediction " } , client ) ;
160
+ const prediction = new Prediction ( { id : "testprediction " } , client ) ;
126
161
127
162
jest . spyOn ( prediction , "load" ) . mockImplementation ( predictionLoad ) ;
128
163
129
164
return prediction ;
130
165
} ) ;
131
166
132
- const prediction = await version . predict (
167
+ const prediction = await model . predict (
133
168
{ text : "test text" } ,
134
169
{ } ,
135
170
{ defaultPollingInterval : 0 }
@@ -140,10 +175,10 @@ describe("predict()", () => {
140
175
} ) ;
141
176
142
177
it ( "retries polling on error" , async ( ) => {
143
- jest . spyOn ( version , "createPrediction" ) . mockResolvedValue (
178
+ jest . spyOn ( model , "createPrediction" ) . mockResolvedValue (
144
179
new Prediction (
145
180
{
146
- id : "test-prediction " ,
181
+ id : "testprediction " ,
147
182
status : PredictionStatus . STARTING ,
148
183
} ,
149
184
client
@@ -172,7 +207,7 @@ describe("predict()", () => {
172
207
( ) =>
173
208
new Prediction (
174
209
{
175
- id : "test-prediction " ,
210
+ id : "testprediction " ,
176
211
status : PredictionStatus . SUCCEEDED ,
177
212
} ,
178
213
client
@@ -182,15 +217,15 @@ describe("predict()", () => {
182
217
const predictionLoad = jest . fn ( ( ) => predictionLoadResults . shift ( ) ( ) ) ;
183
218
184
219
jest . spyOn ( client , "prediction" ) . mockImplementation ( ( ) => {
185
- const prediction = new Prediction ( { id : "test-prediction " } , client ) ;
220
+ const prediction = new Prediction ( { id : "testprediction " } , client ) ;
186
221
187
222
jest . spyOn ( prediction , "load" ) . mockImplementation ( predictionLoad ) ;
188
223
189
224
return prediction ;
190
225
} ) ;
191
226
const backoffFn = jest . fn ( ( ) => 0 ) ;
192
227
193
- const prediction = await version . predict (
228
+ const prediction = await model . predict (
194
229
{ text : "test text" } ,
195
230
{ } ,
196
231
{ defaultPollingInterval : 0 , backoffFn }
@@ -205,14 +240,14 @@ describe("predict()", () => {
205
240
describe ( "createPrediction()" , ( ) => {
206
241
it ( "makes request to create prediction" , async ( ) => {
207
242
jest . spyOn ( client , "request" ) . mockResolvedValue ( {
208
- id : "test-prediction " ,
243
+ id : "testprediction " ,
209
244
status : PredictionStatus . SUCCEEDED ,
210
245
} ) ;
211
246
212
- await version . createPrediction ( { text : "test text" } ) ;
247
+ await model . createPrediction ( { text : "test text" } ) ;
213
248
214
249
expect ( client . request ) . toHaveBeenCalledWith ( "POST /v1/predictions" , {
215
- version : "test-version " ,
250
+ version : "testversion " ,
216
251
input : { text : "test text" } ,
217
252
} ) ;
218
253
} ) ;
0 commit comments