@@ -185,42 +185,84 @@ describe("Replicate client", () => {
185
185
} ) ;
186
186
187
187
describe ( "predictions.create" , ( ) => {
188
- test ( "Calls the correct API route with the correct payload" , async ( ) => {
189
- nock ( BASE_URL )
190
- . post ( "/predictions" )
191
- . reply ( 200 , {
192
- id : "ufawqhfynnddngldkgtslldrkq" ,
193
- model : "replicate/hello-world" ,
194
- version :
195
- "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
196
- urls : {
197
- get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
198
- cancel :
199
- "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
200
- } ,
201
- created_at : "2022-04-26T22:13:06.224088Z" ,
202
- started_at : null ,
203
- completed_at : null ,
204
- status : "starting" ,
205
- input : {
206
- text : "Alice" ,
188
+ const predictionTestCases = [
189
+ {
190
+ description : "String input" ,
191
+ input : {
192
+ text : "Alice" ,
193
+ } ,
194
+ } ,
195
+ {
196
+ description : "Number input" ,
197
+ input : {
198
+ text : 123 ,
199
+ } ,
200
+ } ,
201
+ {
202
+ description : "Boolean input" ,
203
+ input : {
204
+ text : true ,
205
+ } ,
206
+ } ,
207
+ {
208
+ description : "Array input" ,
209
+ input : {
210
+ text : [ "Alice" , "Bob" , "Charlie" ] ,
211
+ } ,
212
+ } ,
213
+ {
214
+ description : "Object input" ,
215
+ input : {
216
+ text : {
217
+ name : "Alice" ,
207
218
} ,
208
- output : null ,
209
- error : null ,
210
- logs : null ,
211
- metrics : { } ,
212
- } ) ;
213
- const prediction = await client . predictions . create ( {
219
+ } ,
220
+ } ,
221
+ ] . map ( ( testCase ) => ( {
222
+ ...testCase ,
223
+ expectedResponse : {
224
+ id : "ufawqhfynnddngldkgtslldrkq" ,
225
+ model : "replicate/hello-world" ,
214
226
version :
215
227
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
216
- input : {
217
- text : "Alice" ,
228
+ urls : {
229
+ get : "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq" ,
230
+ cancel :
231
+ "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel" ,
218
232
} ,
219
- webhook : "http://test.host/webhook" ,
220
- webhook_events_filter : [ "output" , "completed" ] ,
221
- } ) ;
222
- expect ( prediction . id ) . toBe ( "ufawqhfynnddngldkgtslldrkq" ) ;
223
- } ) ;
233
+ input : testCase . input ,
234
+ created_at : "2022-04-26T22:13:06.224088Z" ,
235
+ started_at : null ,
236
+ completed_at : null ,
237
+ status : "starting" ,
238
+ } ,
239
+ } ) ) ;
240
+
241
+ test . each ( predictionTestCases ) (
242
+ "$description" ,
243
+ async ( { input, expectedResponse } ) => {
244
+ nock ( BASE_URL )
245
+ . post ( "/predictions" , {
246
+ version :
247
+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
248
+ input : input as Record < string , any > ,
249
+ webhook : "http://test.host/webhook" ,
250
+ webhook_events_filter : [ "output" , "completed" ] ,
251
+ } )
252
+ . reply ( 200 , expectedResponse ) ;
253
+
254
+ const response = await client . predictions . create ( {
255
+ version :
256
+ "5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" ,
257
+ input : input as Record < string , any > ,
258
+ webhook : "http://test.host/webhook" ,
259
+ webhook_events_filter : [ "output" , "completed" ] ,
260
+ } ) ;
261
+
262
+ expect ( response . input ) . toEqual ( input ) ;
263
+ expect ( response . status ) . toBe ( expectedResponse . status ) ;
264
+ }
265
+ ) ;
224
266
225
267
const fileTestCases = [
226
268
// Skip test case if File type is not available
0 commit comments