Skip to content

Commit cc3d281

Browse files
matttaron
andauthored
Add support for files API endpoints (#184)
* Add support for files API endpoints * Apply suggestions from code review Co-authored-by: Aron Carroll <[email protected]> * Remove prepareInputs in favor of FileEncodingStrategy This allows the user to determine if file uploads, falling back to base64 should be used or just sticking to one approach. The tests have been updated to validate the file upload payload and to ensure the url is correctly passed to the prediction create method. * Remove replicate.files API * Tell Biome to ignore .wrangler directory --------- Co-authored-by: Aron Carroll <[email protected]>
1 parent 87bd3ab commit cc3d281

File tree

8 files changed

+244
-13
lines changed

8 files changed

+244
-13
lines changed

biome.json

+5-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
{
22
"$schema": "https://biomejs.dev/schemas/1.0.0/schema.json",
33
"files": {
4-
"ignore": [".wrangler", "vendor/*"]
4+
"ignore": [
5+
".wrangler",
6+
"node_modules",
7+
"vendor/*"
8+
]
59
},
610
"formatter": {
711
"indentStyle": "space",

index.d.ts

+19
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ declare module "replicate" {
3939
};
4040
}
4141

42+
export interface FileObject {
43+
id: string;
44+
name: string;
45+
content_type: string;
46+
size: number;
47+
etag: string;
48+
checksum: string;
49+
metadata: Record<string, unknown>;
50+
created_at: string;
51+
expires_at: string | null;
52+
urls: {
53+
get: string;
54+
};
55+
}
56+
4257
export interface Hardware {
4358
sku: string;
4459
name: string;
@@ -93,6 +108,8 @@ declare module "replicate" {
93108

94109
export type Training = Prediction;
95110

111+
export type FileEncodingStrategy = "default" | "upload" | "data-uri";
112+
96113
export interface Page<T> {
97114
previous?: string;
98115
next?: string;
@@ -119,12 +136,14 @@ declare module "replicate" {
119136
input: Request | string,
120137
init?: RequestInit
121138
) => Promise<Response>;
139+
fileEncodingStrategy?: FileEncodingStrategy;
122140
});
123141

124142
auth: string;
125143
userAgent?: string;
126144
baseUrl?: string;
127145
fetch: (input: Request | string, init?: RequestInit) => Promise<Response>;
146+
fileEncodingStrategy: FileEncodingStrategy;
128147

129148
run(
130149
identifier: `${string}/${string}` | `${string}/${string}:${string}`,

index.js

+10-1
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ class Replicate {
4646
* @param {string} options.userAgent - Identifier of your app
4747
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
4848
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
49+
* @param {"default" | "upload" | "data-uri"} [options.fileEncodingStrategy] - Determines the file encoding strategy to use
4950
*/
5051
constructor(options = {}) {
5152
this.auth =
@@ -55,6 +56,7 @@ class Replicate {
5556
options.userAgent || `replicate-javascript/${packageJSON.version}`;
5657
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
5758
this.fetch = options.fetch || globalThis.fetch;
59+
this.fileEncodingStrategy = options.fileEncodingStrategy ?? "default";
5860

5961
this.accounts = {
6062
current: accounts.current.bind(this),
@@ -230,10 +232,17 @@ class Replicate {
230232
}
231233
}
232234

235+
let body = undefined;
236+
if (data instanceof FormData) {
237+
body = data;
238+
} else if (data) {
239+
body = JSON.stringify(data);
240+
}
241+
233242
const init = {
234243
method,
235244
headers,
236-
body: data ? JSON.stringify(data) : undefined,
245+
body,
237246
};
238247

239248
const shouldRetry =

index.test.ts

+59-3
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,13 @@ describe("Replicate client", () => {
222222
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
223223
});
224224

225-
test.each([
225+
const fileTestCases = [
226226
// Skip test case if File type is not available
227227
...(typeof File !== "undefined"
228228
? [
229229
{
230230
type: "file",
231-
value: new File(["hello world"], "hello.txt", {
231+
value: new File(["hello world"], "file_hello.txt", {
232232
type: "text/plain",
233233
}),
234234
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
@@ -245,11 +245,67 @@ describe("Replicate client", () => {
245245
value: Buffer.from("hello world"),
246246
expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=",
247247
},
248-
])(
248+
];
249+
250+
test.each(fileTestCases)(
251+
"converts a $type input into a Replicate file URL",
252+
async ({ value: data, type }) => {
253+
const mockedFetch = jest.spyOn(client, "fetch");
254+
255+
nock(BASE_URL)
256+
.post("/files")
257+
.matchHeader("Content-Type", "multipart/form-data")
258+
.reply(201, {
259+
urls: {
260+
get: "https://replicate.com/api/files/123",
261+
},
262+
})
263+
.post(
264+
"/predictions",
265+
(body) => body.input.data === "https://replicate.com/api/files/123"
266+
)
267+
.reply(201, (_uri: string, body: Record<string, any>) => {
268+
return body;
269+
});
270+
271+
const prediction = await client.predictions.create({
272+
version:
273+
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
274+
input: {
275+
prompt: "Tell me a story",
276+
data,
277+
},
278+
stream: true,
279+
});
280+
281+
expect(client.fetch).toHaveBeenCalledWith(
282+
new URL("https://api.replicate.com/v1/files"),
283+
{
284+
method: "POST",
285+
body: expect.any(FormData),
286+
headers: expect.objectContaining({
287+
"Content-Type": "multipart/form-data",
288+
}),
289+
}
290+
);
291+
const form = mockedFetch.mock.calls[0][1]?.body as FormData;
292+
// @ts-ignore
293+
expect(form?.get("content")?.name).toMatch(new RegExp(`^${type}_`));
294+
295+
expect(prediction.input).toEqual({
296+
prompt: "Tell me a story",
297+
data: "https://replicate.com/api/files/123",
298+
});
299+
}
300+
);
301+
302+
test.each(fileTestCases)(
249303
"converts a $type input into a base64 encoded string",
250304
async ({ value: data, expected }) => {
251305
let actual: Record<string, any> | undefined;
252306
nock(BASE_URL)
307+
.post("/files")
308+
.reply(503, "Service Unavailable")
253309
.post("/predictions")
254310
.reply(201, (_uri: string, body: Record<string, any>) => {
255311
actual = body;

lib/deployments.js

+5-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ async function createPrediction(deployment_owner, deployment_name, options) {
3030
method: "POST",
3131
data: {
3232
...data,
33-
input: await transformFileInputs(input),
33+
input: await transformFileInputs(
34+
this,
35+
input,
36+
this.fileEncodingStrategy
37+
),
3438
stream,
3539
},
3640
}

lib/files.js

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/**
2+
* Create a file
3+
*
4+
* @param {object} file - Required. The file object.
5+
* @param {object} metadata - Optional. User-provided metadata associated with the file.
6+
* @returns {Promise<object>} - Resolves with the file data
7+
*/
8+
async function createFile(file, metadata = {}) {
9+
const form = new FormData();
10+
11+
let filename;
12+
let blob;
13+
if (file instanceof Blob) {
14+
filename = file.name || `blob_${Date.now()}`;
15+
blob = file;
16+
} else if (Buffer.isBuffer(file)) {
17+
filename = `buffer_${Date.now()}`;
18+
blob = new Blob(file, { type: "application/octet-stream" });
19+
} else {
20+
throw new Error("Invalid file argument, must be a Blob, File or Buffer");
21+
}
22+
23+
form.append("content", blob, filename);
24+
form.append(
25+
"metadata",
26+
new Blob([JSON.stringify(metadata)], { type: "application/json" })
27+
);
28+
29+
const response = await this.request("/files", {
30+
method: "POST",
31+
data: form,
32+
headers: {
33+
"Content-Type": "multipart/form-data",
34+
},
35+
});
36+
37+
return response.json();
38+
}
39+
40+
/**
41+
* List all files
42+
*
43+
* @returns {Promise<object>} - Resolves with the files data
44+
*/
45+
async function listFiles() {
46+
const response = await this.request("/files", {
47+
method: "GET",
48+
});
49+
50+
return response.json();
51+
}
52+
53+
/**
54+
* Get a file
55+
*
56+
* @param {string} file_id - Required. The ID of the file.
57+
* @returns {Promise<object>} - Resolves with the file data
58+
*/
59+
async function getFile(file_id) {
60+
const response = await this.request(`/files/${file_id}`, {
61+
method: "GET",
62+
});
63+
64+
return response.json();
65+
}
66+
67+
/**
68+
* Delete a file
69+
*
70+
* @param {string} file_id - Required. The ID of the file.
71+
* @returns {Promise<object>} - Resolves with the deletion confirmation
72+
*/
73+
async function deleteFile(file_id) {
74+
const response = await this.request(`/files/${file_id}`, {
75+
method: "DELETE",
76+
});
77+
78+
return response.json();
79+
}
80+
81+
module.exports = {
82+
create: createFile,
83+
list: listFiles,
84+
get: getFile,
85+
delete: deleteFile,
86+
};

lib/predictions.js

+10-2
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@ async function createPrediction(options) {
3030
method: "POST",
3131
data: {
3232
...data,
33-
input: await transformFileInputs(input),
33+
input: await transformFileInputs(
34+
this,
35+
input,
36+
this.fileEncodingStrategy
37+
),
3438
version,
3539
stream,
3640
},
@@ -40,7 +44,11 @@ async function createPrediction(options) {
4044
method: "POST",
4145
data: {
4246
...data,
43-
input: await transformFileInputs(input),
47+
input: await transformFileInputs(
48+
this,
49+
input,
50+
this.fileEncodingStrategy
51+
),
4452
stream,
4553
},
4654
});

lib/util.js

+50-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
const ApiError = require("./error");
2+
const { create: createFile } = require("./files");
23

34
/**
45
* @see {@link validateWebhook}
@@ -209,12 +210,58 @@ async function withAutomaticRetries(request, options = {}) {
209210
}
210211
attempts += 1;
211212
}
212-
/* eslint-enable no-await-in-loop */
213213
} while (attempts < maxRetries);
214214

215215
return request();
216216
}
217217

218+
/**
219+
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
220+
* and replaces the input with the URL of the uploaded file.
221+
*
222+
* @param {Replicate} client - The client used to upload the file
223+
* @param {object} inputs - The inputs to transform
224+
* @param {"default" | "upload" | "data-uri"} strategy - Whether to upload files to Replicate, encode as dataURIs or try both.
225+
* @returns {object} - The transformed inputs
226+
* @throws {ApiError} If the request to upload the file fails
227+
*/
228+
async function transformFileInputs(client, inputs, strategy) {
229+
switch (strategy) {
230+
case "data-uri":
231+
return await transformFileInputsToBase64EncodedDataURIs(client, inputs);
232+
case "upload":
233+
return await transformFileInputsToReplicateFileURLs(client, inputs);
234+
case "default":
235+
try {
236+
return await transformFileInputsToReplicateFileURLs(client, inputs);
237+
} catch (error) {
238+
return await transformFileInputsToBase64EncodedDataURIs(inputs);
239+
}
240+
default:
241+
throw new Error(`Unexpected file upload strategy: ${strategy}`);
242+
}
243+
}
244+
245+
/**
246+
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
247+
* and replaces the input with the URL of the uploaded file.
248+
*
249+
* @param {Replicate} client - The client used to upload the file
250+
* @param {object} inputs - The inputs to transform
251+
* @returns {object} - The transformed inputs
252+
* @throws {ApiError} If the request to upload the file fails
253+
*/
254+
async function transformFileInputsToReplicateFileURLs(client, inputs) {
255+
return await transform(inputs, async (value) => {
256+
if (value instanceof Blob || value instanceof Buffer) {
257+
const file = await createFile.call(client, value);
258+
return file.urls.get;
259+
}
260+
261+
return value;
262+
});
263+
}
264+
218265
const MAX_DATA_URI_SIZE = 10_000_000;
219266

220267
/**
@@ -225,9 +272,9 @@ const MAX_DATA_URI_SIZE = 10_000_000;
225272
* @returns {object} - The transformed inputs
226273
* @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE
227274
*/
228-
async function transformFileInputs(inputs) {
275+
async function transformFileInputsToBase64EncodedDataURIs(inputs) {
229276
let totalBytes = 0;
230-
const result = await transform(inputs, async (value) => {
277+
return await transform(inputs, async (value) => {
231278
let buffer;
232279
let mime;
233280

@@ -258,8 +305,6 @@ async function transformFileInputs(inputs) {
258305

259306
return `data:${mime};base64,${data}`;
260307
});
261-
262-
return result;
263308
}
264309

265310
// Walk a JavaScript object and transform the leaf values.

0 commit comments

Comments
 (0)