Skip to content

Commit 37c1f74

Browse files
aronzeke
authored andcommitted
Remove prepareInputs in favor of FileEncodingStrategy
This allows the user to determine if file uploads, falling back to base64 should be used or just sticking to one approach. The tests have been updated to validate the file upload payload and to ensure the url is correctly passed to the prediction create method.
1 parent 1bf586b commit 37c1f74

File tree

6 files changed

+76
-27
lines changed

6 files changed

+76
-27
lines changed

index.d.ts

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ declare module "replicate" {
108108

109109
export type Training = Prediction;
110110

111+
export type FileEncodingStrategy = "default" | "upload" | "data-uri";
112+
111113
export interface Page<T> {
112114
previous?: string;
113115
next?: string;
@@ -134,16 +136,14 @@ declare module "replicate" {
134136
input: Request | string,
135137
init?: RequestInit
136138
) => Promise<Response>;
137-
prepareInputs?: (
138-
input: Record<string, any>
139-
) => Promise<Record<string, any>>;
139+
fileEncodingStrategy?: FileEncodingStrategy;
140140
});
141141

142142
auth: string;
143143
userAgent?: string;
144144
baseUrl?: string;
145145
fetch: (input: Request | string, init?: RequestInit) => Promise<Response>;
146-
prepareInputs: (input: Record<string, any>) => Promise<Record<string, any>>;
146+
fileEncodingStrategy: FileEncodingStrategy;
147147

148148
run(
149149
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
@@ -339,9 +339,4 @@ declare module "replicate" {
339339
current: number;
340340
total: number;
341341
} | null;
342-
343-
export function transformFileInputs(
344-
inputs: Record<string, any>,
345-
options: { fallbackToDataURI: boolean }
346-
): Promise<Record<string, any>>;
347342
}

index.js

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class Replicate {
4949
* @param {string} options.userAgent - Identifier of your app
5050
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
5151
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
52-
* @param {Function} [options.prepareInput] - Function to prepare input data before sending it to the API.
52+
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
5353
*/
5454
constructor(options = {}) {
5555
this.auth =
@@ -59,15 +59,7 @@ class Replicate {
5959
options.userAgent || `replicate-javascript/${packageJSON.version}`;
6060
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
6161
this.fetch = options.fetch || globalThis.fetch;
62-
this.prepareInputs =
63-
options.prepareInputs ||
64-
(async (inputs) => {
65-
try {
66-
return await transformFileInputsToReplicateFileURLs(this, inputs);
67-
} catch (error) {
68-
return await transformFileInputsToBase64EncodedDataURIs(inputs);
69-
}
70-
});
62+
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
7163

7264
this.accounts = {
7365
current: accounts.current.bind(this),

index.test.ts

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ describe("Replicate client", () => {
228228
? [
229229
{
230230
type: "file",
231-
value: new File(["hello world"], "hello.txt", {
231+
value: new File(["hello world"], "file_hello.txt", {
232232
type: "text/plain",
233233
}),
234234
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
@@ -249,16 +249,22 @@ describe("Replicate client", () => {
249249

250250
test.each(fileTestCases)(
251251
"converts a $type input into a Replicate file URL",
252-
async ({ value: data, expected }) => {
252+
async ({ value: data, type }) => {
253+
const mockedFetch = jest.spyOn(client, "fetch");
254+
253255
nock(BASE_URL)
254256
.post("/files")
257+
.matchHeader("Content-Type", "multipart/form-data")
255258
.reply(201, {
256259
urls: {
257260
get: "https://replicate.com/api/files/123",
258261
},
259262
})
260-
.post("/predictions")
261-
.reply(201, (uri: string, body: Record<string, any>) => {
263+
.post(
264+
"/predictions",
265+
(body) => body.input.data === "https://replicate.com/api/files/123"
266+
)
267+
.reply(201, (_uri: string, body: Record<string, any>) => {
262268
return body;
263269
});
264270

@@ -272,6 +278,20 @@ describe("Replicate client", () => {
272278
stream: true,
273279
});
274280

281+
expect(client.fetch).toHaveBeenCalledWith(
282+
new URL("https://api.replicate.com/v1/files"),
283+
{
284+
method: "POST",
285+
body: expect.any(FormData),
286+
headers: expect.objectContaining({
287+
"Content-Type": "multipart/form-data",
288+
}),
289+
}
290+
);
291+
const form = mockedFetch.mock.calls[0][1]?.body as FormData;
292+
// @ts-ignore
293+
expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`));
294+
275295
expect(prediction.input).toEqual({
276296
prompt: "Tell me a story",
277297
data: "https://replicate.com/api/files/123",

lib/deployments.js

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const { transformFileInputs } = require("./util");
2+
13
/**
24
* Create a new prediction with a deployment
35
*
@@ -28,7 +30,11 @@ async function createPrediction(deployment_owner, deployment_name, options) {
2830
method: "POST",
2931
data: {
3032
...data,
31-
input: await this.prepareInputs(input),
33+
input: await transformFileInputs(
34+
this,
35+
input,
36+
this.fileEncodingStrategy
37+
),
3238
stream,
3339
},
3440
}

lib/predictions.js

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const { transformFileInputs } = require("./util");
2+
13
/**
24
* Create a new prediction
35
*
@@ -28,7 +30,11 @@ async function createPrediction(options) {
2830
method: "POST",
2931
data: {
3032
...data,
31-
input: await this.prepareInputs(input),
33+
input: await transformFileInputs(
34+
this,
35+
input,
36+
this.fileEncodingStrategy
37+
),
3238
version,
3339
stream,
3440
},
@@ -38,7 +44,11 @@ async function createPrediction(options) {
3844
method: "POST",
3945
data: {
4046
...data,
41-
input: await this.prepareInputs(input),
47+
input: await transformFileInputs(
48+
this,
49+
input,
50+
this.fileEncodingStrategy
51+
),
4252
stream,
4353
},
4454
});

lib/util.js

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,38 @@ async function withAutomaticRetries(request, options = {}) {
209209
}
210210
attempts += 1;
211211
}
212-
/* eslint-enable no-await-in-loop */
213212
} while (attempts < maxRetries);
214213

215214
return request();
216215
}
217216

217+
/**
218+
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
219+
* and replaces the input with the URL of the uploaded file.
220+
*
221+
* @param {Replicate} client - The client used to upload the file
222+
* @param {object} inputs - The inputs to transform
223+
* @param {"default" | "upload" | "data-uri"} strategy - Whether to upload files to Replicate, encode as dataURIs or try both.
224+
* @returns {object} - The transformed inputs
225+
* @throws {ApiError} If the request to upload the file fails
226+
*/
227+
async function transformFileInputs(client, inputs, strategy) {
228+
switch (strategy) {
229+
case "data-uri":
230+
return await transformFileInputsToBase64EncodedDataURIs(client, inputs);
231+
case "upload":
232+
return await transformFileInputsToReplicateFileURLs(client, inputs);
233+
case "default":
234+
try {
235+
return await transformFileInputsToReplicateFileURLs(client, inputs);
236+
} catch (error) {
237+
return await transformFileInputsToBase64EncodedDataURIs(inputs);
238+
}
239+
default:
240+
throw new Error(`Unexpected file upload strategy: ${strategy}`);
241+
}
242+
}
243+
218244
/**
219245
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
220246
* and replaces the input with the URL of the uploaded file.

0 commit comments

Comments
 (0)