Skip to content

Commit 996b316

Browse files
committed
Remove prepareInputs in favor of FileEncodingStrategy
This allows the user to determine if file uploads, falling back to base64 should be used or just sticking to one approach. The tests have been updated to validate the file upload payload and to ensure the url is correctly passed to the prediction create method.
1 parent f115b91 commit 996b316

File tree

6 files changed

+76
-27
lines changed

6 files changed

+76
-27
lines changed

index.d.ts

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,8 @@ declare module "replicate" {
108108

109109
export type Training = Prediction;
110110

111+
export type FileEncodingStrategy = "default" | "upload" | "data-uri";
112+
111113
export interface Page<T> {
112114
previous?: string;
113115
next?: string;
@@ -134,16 +136,14 @@ declare module "replicate" {
134136
input: Request | string,
135137
init?: RequestInit
136138
) => Promise<Response>;
137-
prepareInputs?: (
138-
input: Record<string, any>
139-
) => Promise<Record<string, any>>;
139+
fileEncodingStrategy?: FileEncodingStrategy;
140140
});
141141

142142
auth: string;
143143
userAgent?: string;
144144
baseUrl?: string;
145145
fetch: (input: Request | string, init?: RequestInit) => Promise<Response>;
146-
prepareInputs: (input: Record<string, any>) => Promise<Record<string, any>>;
146+
fileEncodingStrategy: FileEncodingStrategy;
147147

148148
run(
149149
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
@@ -315,9 +315,4 @@ declare module "replicate" {
315315
current: number;
316316
total: number;
317317
} | null;
318-
319-
export function transformFileInputs(
320-
inputs: Record<string, any>,
321-
options: { fallbackToDataURI: boolean }
322-
): Promise<Record<string, any>>;
323318
}

index.js

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,15 @@ class Replicate {
4848
* @param {string} options.userAgent - Identifier of your app
4949
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
5050
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
51-
* @param {Function} [options.prepareInput] - Function to prepare input data before sending it to the API.
51+
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
5252
*/
5353
constructor(options = {}) {
5454
this.auth = options.auth || process.env.REPLICATE_API_TOKEN;
5555
this.userAgent =
5656
options.userAgent || `replicate-javascript/${packageJSON.version}`;
5757
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
5858
this.fetch = options.fetch || globalThis.fetch;
59-
this.prepareInputs =
60-
options.prepareInputs ||
61-
(async (inputs) => {
62-
try {
63-
return await transformFileInputsToReplicateFileURLs(this, inputs);
64-
} catch (error) {
65-
return await transformFileInputsToBase64EncodedDataURIs(inputs);
66-
}
67-
});
59+
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
6860

6961
this.accounts = {
7062
current: accounts.current.bind(this),

index.test.ts

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ describe("Replicate client", () => {
228228
? [
229229
{
230230
type: "file",
231-
value: new File(["hello world"], "hello.txt", {
231+
value: new File(["hello world"], "file_hello.txt", {
232232
type: "text/plain",
233233
}),
234234
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
@@ -249,16 +249,22 @@ describe("Replicate client", () => {
249249

250250
test.each(fileTestCases)(
251251
"converts a $type input into a Replicate file URL",
252-
async ({ value: data, expected }) => {
252+
async ({ value: data, type }) => {
253+
const mockedFetch = jest.spyOn(client, "fetch");
254+
253255
nock(BASE_URL)
254256
.post("/files")
257+
.matchHeader("Content-Type", "multipart/form-data")
255258
.reply(201, {
256259
urls: {
257260
get: "https://replicate.com/api/files/123",
258261
},
259262
})
260-
.post("/predictions")
261-
.reply(201, (uri: string, body: Record<string, any>) => {
263+
.post(
264+
"/predictions",
265+
(body) => body.input.data === "https://replicate.com/api/files/123"
266+
)
267+
.reply(201, (_uri: string, body: Record<string, any>) => {
262268
return body;
263269
});
264270

@@ -272,6 +278,20 @@ describe("Replicate client", () => {
272278
stream: true,
273279
});
274280

281+
expect(client.fetch).toHaveBeenCalledWith(
282+
new URL("https://api.replicate.com/v1/files"),
283+
{
284+
method: "POST",
285+
body: expect.any(FormData),
286+
headers: expect.objectContaining({
287+
"Content-Type": "multipart/form-data",
288+
}),
289+
}
290+
);
291+
const form = mockedFetch.mock.calls[0][1]?.body as FormData;
292+
// @ts-ignore
293+
expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`));
294+
275295
expect(prediction.input).toEqual({
276296
prompt: "Tell me a story",
277297
data: "https://replicate.com/api/files/123",

lib/deployments.js

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const { transformFileInputs } = require("./util");
2+
13
/**
24
* Create a new prediction with a deployment
35
*
@@ -28,7 +30,11 @@ async function createPrediction(deployment_owner, deployment_name, options) {
2830
method: "POST",
2931
data: {
3032
...data,
31-
input: await this.prepareInputs(input),
33+
input: await transformFileInputs(
34+
this,
35+
input,
36+
this.fileEncodingStrategy
37+
),
3238
stream,
3339
},
3440
}

lib/predictions.js

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const { transformFileInputs } = require("./util");
2+
13
/**
24
* Create a new prediction
35
*
@@ -28,7 +30,11 @@ async function createPrediction(options) {
2830
method: "POST",
2931
data: {
3032
...data,
31-
input: await this.prepareInputs(input),
33+
input: await transformFileInputs(
34+
this,
35+
input,
36+
this.fileEncodingStrategy
37+
),
3238
version,
3339
stream,
3440
},
@@ -38,7 +44,11 @@ async function createPrediction(options) {
3844
method: "POST",
3945
data: {
4046
...data,
41-
input: await this.prepareInputs(input),
47+
input: await transformFileInputs(
48+
this,
49+
input,
50+
this.fileEncodingStrategy
51+
),
4252
stream,
4353
},
4454
});

lib/util.js

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,12 +150,38 @@ async function withAutomaticRetries(request, options = {}) {
150150
}
151151
attempts += 1;
152152
}
153-
/* eslint-enable no-await-in-loop */
154153
} while (attempts < maxRetries);
155154

156155
return request();
157156
}
158157

158+
/**
159+
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
160+
* and replaces the input with the URL of the uploaded file.
161+
*
162+
* @param {Replicate} client - The client used to upload the file
163+
* @param {object} inputs - The inputs to transform
164+
* @param {"default" | "upload" | "data-uri"} strategy - Whether to upload files to Replicate, encode as dataURIs or try both.
165+
* @returns {object} - The transformed inputs
166+
* @throws {ApiError} If the request to upload the file fails
167+
*/
168+
async function transformFileInputs(client, inputs, strategy) {
169+
switch (strategy) {
170+
case "data-uri":
171+
return await transformFileInputsToBase64EncodedDataURIs(client, inputs);
172+
case "upload":
173+
return await transformFileInputsToReplicateFileURLs(client, inputs);
174+
case "default":
175+
try {
176+
return await transformFileInputsToReplicateFileURLs(client, inputs);
177+
} catch (error) {
178+
return await transformFileInputsToBase64EncodedDataURIs(inputs);
179+
}
180+
default:
181+
throw new Error(`Unexpected file upload strategy: ${strategy}`);
182+
}
183+
}
184+
159185
/**
160186
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
161187
* and replaces the input with the URL of the uploaded file.

0 commit comments

Comments
 (0)