Skip to content

Commit ee62cc4

Browse files
matttzeke
authored andcommitted
Add support for files API endpoints
1 parent bb5ddaf commit ee62cc4

File tree

7 files changed

+213
-14
lines changed

7 files changed

+213
-14
lines changed

index.d.ts

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ declare module "replicate" {
3939
};
4040
}
4141

42+
export interface FileObject {
43+
id: string;
44+
name: string;
45+
content_type: string;
46+
size: number;
47+
etag: string;
48+
checksum: string;
49+
metadata: Record<string, any>;
50+
created_at: string;
51+
expires_at: string | null;
52+
urls: {
53+
get: string;
54+
};
55+
}
56+
4257
export interface Hardware {
4358
sku: string;
4459
name: string;
@@ -119,12 +134,16 @@ declare module "replicate" {
119134
input: Request | string,
120135
init?: RequestInit
121136
) => Promise<Response>;
137+
prepareInputs?: (
138+
input: Record<string, any>
139+
) => Promise<Record<string, any>>;
122140
});
123141

124142
auth: string;
125143
userAgent?: string;
126144
baseUrl?: string;
127145
fetch: (input: Request | string, init?: RequestInit) => Promise<Response>;
146+
prepareInputs: (input: Record<string, any>) => Promise<Record<string, any>>;
128147

129148
run(
130149
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
@@ -220,6 +239,16 @@ declare module "replicate" {
220239
list(): Promise<Page<Deployment>>;
221240
};
222241

242+
files: {
243+
create: (
244+
file: File | Blob,
245+
metadata?: Record<string, string | number | boolean | null>
246+
) => Promise<FileObject>;
247+
list: () => Promise<FileObject>;
248+
get: (file_id: string) => Promise<FileObject>;
249+
delete: (file_id: string) => Promise<FileObject>;
250+
};
251+
223252
hardware: {
224253
list(): Promise<Hardware[]>;
225254
};
@@ -310,4 +339,9 @@ declare module "replicate" {
310339
current: number;
311340
total: number;
312341
} | null;
342+
343+
export function transformFileInputs(
344+
inputs: Record<string, any>,
345+
options: { fallbackToDataURI: boolean }
346+
): Promise<Record<string, any>>;
313347
}

index.js

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,14 @@ const {
66
validateWebhook,
77
parseProgressFromLogs,
88
streamAsyncIterator,
9+
transformFileInputsToReplicateFileURLs,
10+
transformFileInputsToBase64EncodedDataURIs,
911
} = require("./lib/util");
1012

1113
const accounts = require("./lib/accounts");
1214
const collections = require("./lib/collections");
1315
const deployments = require("./lib/deployments");
16+
const files = require("./lib/files");
1417
const hardware = require("./lib/hardware");
1518
const models = require("./lib/models");
1619
const predictions = require("./lib/predictions");
@@ -46,6 +49,7 @@ class Replicate {
4649
* @param {string} options.userAgent - Identifier of your app
4750
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
4851
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
52+
* @param {Function} [options.prepareInput] - Function to prepare input data before sending it to the API.
4953
*/
5054
constructor(options = {}) {
5155
this.auth =
@@ -55,6 +59,15 @@ class Replicate {
5559
options.userAgent || `replicate-javascript/${packageJSON.version}`;
5660
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
5761
this.fetch = options.fetch || globalThis.fetch;
62+
this.prepareInputs =
63+
options.prepareInputs ||
64+
(async (inputs) => {
65+
try {
66+
return await transformFileInputsToReplicateFileURLs(this, inputs);
67+
} catch (error) {
68+
return await transformFileInputsToBase64EncodedDataURIs(inputs);
69+
}
70+
});
5871

5972
this.accounts = {
6073
current: accounts.current.bind(this),
@@ -75,6 +88,13 @@ class Replicate {
7588
},
7689
};
7790

91+
this.files = {
92+
create: files.create.bind(this),
93+
list: files.list.bind(this),
94+
get: files.get.bind(this),
95+
delete: files.delete.bind(this),
96+
};
97+
7898
this.hardware = {
7999
list: hardware.list.bind(this),
80100
};
@@ -230,10 +250,17 @@ class Replicate {
230250
}
231251
}
232252

253+
let body = undefined;
254+
if (data instanceof FormData) {
255+
body = data;
256+
} else if (data) {
257+
body = JSON.stringify(data);
258+
}
259+
233260
const init = {
234261
method,
235262
headers,
236-
body: data ? JSON.stringify(data) : undefined,
263+
body,
237264
};
238265

239266
const shouldRetry =

index.test.ts

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ describe("Replicate client", () => {
222222
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
223223
});
224224

225-
test.each([
225+
const fileTestCases = [
226226
// Skip test case if File type is not available
227227
...(typeof File !== "undefined"
228228
? [
@@ -245,11 +245,47 @@ describe("Replicate client", () => {
245245
value: Buffer.from("hello world"),
246246
expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=",
247247
},
248-
])(
248+
];
249+
250+
test.each(fileTestCases)(
251+
"converts a $type input into a Replicate file URL",
252+
async ({ value: data, expected }) => {
253+
nock(BASE_URL)
254+
.post("/files")
255+
.reply(201, {
256+
urls: {
257+
get: "https://replicate.com/api/files/123",
258+
},
259+
})
260+
.post("/predictions")
261+
.reply(201, (uri: string, body: Record<string, any>) => {
262+
return body;
263+
});
264+
265+
const prediction = await client.predictions.create({
266+
version:
267+
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
268+
input: {
269+
prompt: "Tell me a story",
270+
data,
271+
},
272+
stream: true,
273+
});
274+
275+
expect(prediction.input).toEqual({
276+
prompt: "Tell me a story",
277+
data: "https://replicate.com/api/files/123",
278+
});
279+
}
280+
);
281+
282+
test.each(fileTestCases)(
249283
"converts a $type input into a base64 encoded string",
250284
async ({ value: data, expected }) => {
251285
let actual: Record<string, any> | undefined;
252286
nock(BASE_URL)
287+
.post("/files")
288+
.reply(503, "Service Unavailable")
253289
.post("/predictions")
254290
.reply(201, (_uri: string, body: Record<string, any>) => {
255291
actual = body;

lib/deployments.js

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const { transformFileInputs } = require("./util");
2-
31
/**
42
* Create a new prediction with a deployment
53
*
@@ -30,7 +28,7 @@ async function createPrediction(deployment_owner, deployment_name, options) {
3028
method: "POST",
3129
data: {
3230
...data,
33-
input: await transformFileInputs(input),
31+
input: await this.prepareInputs(input),
3432
stream,
3533
},
3634
}

lib/files.js

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/**
2+
* Create a file
3+
*
4+
* @param {object} file - Required. The file object.
5+
* @param {object} metadata - Optional. User-provided metadata associated with the file.
6+
* @returns {Promise<object>} - Resolves with the file data
7+
*/
8+
async function createFile(file, metadata = {}) {
9+
const form = new FormData();
10+
11+
let filename;
12+
let blob;
13+
if (file instanceof Blob) {
14+
fileName = file.name || `blob_${Date.now()}`;
15+
blob = file;
16+
} else if (Buffer.isBuffer(file)) {
17+
fileName = `buffer_${Date.now()}`;
18+
blob = new Blob(file, { type: "application/octet-stream" });
19+
} else {
20+
throw new Error("Invalid file argument, must be a Blob or File");
21+
}
22+
23+
form.append("content", blob, filename);
24+
form.append(
25+
"metadata",
26+
new Blob([JSON.stringify(metadata)], { type: "application/json" })
27+
);
28+
29+
const response = await this.request("/files", {
30+
method: "POST",
31+
data: form,
32+
headers: {
33+
"Content-Type": "multipart/form-data",
34+
},
35+
});
36+
37+
return response.json();
38+
}
39+
40+
/**
41+
* List all files
42+
*
43+
* @returns {Promise<object>} - Resolves with the files data
44+
*/
45+
async function listFiles() {
46+
const response = await this.request("/files", {
47+
method: "GET",
48+
});
49+
50+
return response.json();
51+
}
52+
53+
/**
54+
* Get a file
55+
*
56+
* @param {string} file_id - Required. The ID of the file.
57+
* @returns {Promise<object>} - Resolves with the file data
58+
*/
59+
async function getFile(file_id) {
60+
const response = await this.request(`/files/${file_id}`, {
61+
method: "GET",
62+
});
63+
64+
return response.json();
65+
}
66+
67+
/**
68+
* Delete a file
69+
*
70+
* @param {string} file_id - Required. The ID of the file.
71+
* @returns {Promise<object>} - Resolves with the deletion confirmation
72+
*/
73+
async function deleteFile(file_id) {
74+
const response = await this.request(`/files/${file_id}`, {
75+
method: "DELETE",
76+
});
77+
78+
return response.json();
79+
}
80+
81+
module.exports = {
82+
create: createFile,
83+
list: listFiles,
84+
get: getFile,
85+
delete: deleteFile,
86+
};

lib/predictions.js

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const { transformFileInputs } = require("./util");
2-
31
/**
42
* Create a new prediction
53
*
@@ -30,7 +28,7 @@ async function createPrediction(options) {
3028
method: "POST",
3129
data: {
3230
...data,
33-
input: await transformFileInputs(input),
31+
input: await this.prepareInputs(input),
3432
version,
3533
stream,
3634
},
@@ -40,7 +38,7 @@ async function createPrediction(options) {
4038
method: "POST",
4139
data: {
4240
...data,
43-
input: await transformFileInputs(input),
41+
input: await this.prepareInputs(input),
4442
stream,
4543
},
4644
});

lib/util.js

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,26 @@ async function withAutomaticRetries(request, options = {}) {
215215
return request();
216216
}
217217

218+
/**
219+
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
220+
* and replaces the input with the URL of the uploaded file.
221+
*
222+
* @param {Replicate} client - The client used to upload the file
223+
* @param {object} inputs - The inputs to transform
224+
* @returns {object} - The transformed inputs
225+
* @throws {ApiError} If the request to upload the file fails
226+
*/
227+
async function transformFileInputsToReplicateFileURLs(client, inputs) {
228+
return await transform(inputs, async (value) => {
229+
if (value instanceof Blob || value instanceof Buffer) {
230+
const file = await client.files.create(value);
231+
return file.urls.get;
232+
}
233+
234+
return value;
235+
});
236+
}
237+
218238
const MAX_DATA_URI_SIZE = 10_000_000;
219239

220240
/**
@@ -225,9 +245,9 @@ const MAX_DATA_URI_SIZE = 10_000_000;
225245
* @returns {object} - The transformed inputs
226246
* @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE
227247
*/
228-
async function transformFileInputs(inputs) {
248+
async function transformFileInputsToBase64EncodedDataURIs(inputs) {
229249
let totalBytes = 0;
230-
const result = await transform(inputs, async (value) => {
250+
return await transform(inputs, async (value) => {
231251
let buffer;
232252
let mime;
233253

@@ -258,8 +278,6 @@ async function transformFileInputs(inputs) {
258278

259279
return `data:${mime};base64,${data}`;
260280
});
261-
262-
return result;
263281
}
264282

265283
// Walk a JavaScript object and transform the leaf values.
@@ -394,4 +412,6 @@ module.exports = {
394412
withAutomaticRetries,
395413
parseProgressFromLogs,
396414
streamAsyncIterator,
415+
transformFileInputsToBase64EncodedDataURIs,
416+
transformFileInputsToReplicateFileURLs,
397417
};

0 commit comments

Comments
 (0)