Skip to content

Commit d981fc1

Browse files
authored
Fix regression in how array input values are transformed (#266)
* Fix regression in how array input values are transformed * Refactor tests
1 parent e059286 commit d981fc1

File tree

2 files changed

+77
-34
lines changed

2 files changed

+77
-34
lines changed

index.test.ts

Lines changed: 74 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -185,42 +185,84 @@ describe("Replicate client", () => {
185185
});
186186

187187
describe("predictions.create", () => {
188-
test("Calls the correct API route with the correct payload", async () => {
189-
nock(BASE_URL)
190-
.post("/predictions")
191-
.reply(200, {
192-
id: "ufawqhfynnddngldkgtslldrkq",
193-
model: "replicate/hello-world",
194-
version:
195-
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
196-
urls: {
197-
get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
198-
cancel:
199-
"https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
200-
},
201-
created_at: "2022-04-26T22:13:06.224088Z",
202-
started_at: null,
203-
completed_at: null,
204-
status: "starting",
205-
input: {
206-
text: "Alice",
188+
const predictionTestCases = [
189+
{
190+
description: "String input",
191+
input: {
192+
text: "Alice",
193+
},
194+
},
195+
{
196+
description: "Number input",
197+
input: {
198+
text: 123,
199+
},
200+
},
201+
{
202+
description: "Boolean input",
203+
input: {
204+
text: true,
205+
},
206+
},
207+
{
208+
description: "Array input",
209+
input: {
210+
text: ["Alice", "Bob", "Charlie"],
211+
},
212+
},
213+
{
214+
description: "Object input",
215+
input: {
216+
text: {
217+
name: "Alice",
207218
},
208-
output: null,
209-
error: null,
210-
logs: null,
211-
metrics: {},
212-
});
213-
const prediction = await client.predictions.create({
219+
},
220+
},
221+
].map((testCase) => ({
222+
...testCase,
223+
expectedResponse: {
224+
id: "ufawqhfynnddngldkgtslldrkq",
225+
model: "replicate/hello-world",
214226
version:
215227
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
216-
input: {
217-
text: "Alice",
228+
urls: {
229+
get: "https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq",
230+
cancel:
231+
"https://api.replicate.com/v1/predictions/ufawqhfynnddngldkgtslldrkq/cancel",
218232
},
219-
webhook: "http://test.host/webhook",
220-
webhook_events_filter: ["output", "completed"],
221-
});
222-
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
223-
});
233+
input: testCase.input,
234+
created_at: "2022-04-26T22:13:06.224088Z",
235+
started_at: null,
236+
completed_at: null,
237+
status: "starting",
238+
},
239+
}));
240+
241+
test.each(predictionTestCases)(
242+
"$description",
243+
async ({ input, expectedResponse }) => {
244+
nock(BASE_URL)
245+
.post("/predictions", {
246+
version:
247+
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
248+
input: input as Record<string, any>,
249+
webhook: "http://test.host/webhook",
250+
webhook_events_filter: ["output", "completed"],
251+
})
252+
.reply(200, expectedResponse);
253+
254+
const response = await client.predictions.create({
255+
version:
256+
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
257+
input: input as Record<string, any>,
258+
webhook: "http://test.host/webhook",
259+
webhook_events_filter: ["output", "completed"],
260+
});
261+
262+
expect(response.input).toEqual(input);
263+
expect(response.status).toBe(expectedResponse.status);
264+
}
265+
);
224266

225267
const fileTestCases = [
226268
// Skip test case if File type is not available

lib/util.js

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,9 +310,10 @@ async function transformFileInputsToBase64EncodedDataURIs(inputs) {
310310
// Walk a JavaScript object and transform the leaf values.
311311
async function transform(value, mapper) {
312312
if (Array.isArray(value)) {
313-
let copy = [];
313+
const copy = [];
314314
for (const val of value) {
315-
copy = await transform(val, mapper);
315+
const transformed = await transform(val, mapper);
316+
copy.push(transformed);
316317
}
317318
return copy;
318319
}

0 commit comments

Comments
 (0)