Skip to content

Commit db66ee7

Browse files
matttaron
authored andcommitted
Add support for files API endpoints
1 parent 60a8e18 commit db66ee7

File tree

7 files changed

+213
-14
lines changed

7 files changed

+213
-14
lines changed

index.d.ts

+34
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ declare module "replicate" {
3939
};
4040
}
4141

42+
export interface FileObject {
43+
id: string;
44+
name: string;
45+
content_type: string;
46+
size: number;
47+
etag: string;
48+
checksum: string;
49+
metadata: Record<string, any>;
50+
created_at: string;
51+
expires_at: string | null;
52+
urls: {
53+
get: string;
54+
};
55+
}
56+
4257
export interface Hardware {
4358
sku: string;
4459
name: string;
@@ -119,12 +134,16 @@ declare module "replicate" {
119134
input: Request | string,
120135
init?: RequestInit
121136
) => Promise<Response>;
137+
prepareInputs?: (
138+
input: Record<string, any>
139+
) => Promise<Record<string, any>>;
122140
});
123141

124142
auth: string;
125143
userAgent?: string;
126144
baseUrl?: string;
127145
fetch: (input: Request | string, init?: RequestInit) => Promise<Response>;
146+
prepareInputs: (input: Record<string, any>) => Promise<Record<string, any>>;
128147

129148
run(
130149
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
@@ -196,6 +215,16 @@ declare module "replicate" {
196215
): Promise<Deployment>;
197216
};
198217

218+
files: {
219+
create: (
220+
file: File | Blob,
221+
metadata?: Record<string, string | number | boolean | null>
222+
) => Promise<FileObject>;
223+
list: () => Promise<FileObject>;
224+
get: (file_id: string) => Promise<FileObject>;
225+
delete: (file_id: string) => Promise<FileObject>;
226+
};
227+
199228
hardware: {
200229
list(): Promise<Hardware[]>;
201230
};
@@ -286,4 +315,9 @@ declare module "replicate" {
286315
current: number;
287316
total: number;
288317
} | null;
318+
319+
export function transformFileInputs(
320+
inputs: Record<string, any>,
321+
options: { fallbackToDataURI: boolean }
322+
): Promise<Record<string, any>>;
289323
}

index.js

+28-1
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,14 @@ const {
55
withAutomaticRetries,
66
validateWebhook,
77
parseProgressFromLogs,
8+
transformFileInputsToReplicateFileURLs,
9+
transformFileInputsToBase64EncodedDataURIs,
810
} = require("./lib/util");
911

1012
const accounts = require("./lib/accounts");
1113
const collections = require("./lib/collections");
1214
const deployments = require("./lib/deployments");
15+
const files = require("./lib/files");
1316
const hardware = require("./lib/hardware");
1417
const models = require("./lib/models");
1518
const predictions = require("./lib/predictions");
@@ -45,13 +48,23 @@ class Replicate {
4548
* @param {string} options.userAgent - Identifier of your app
4649
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
4750
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
51+
* @param {Function} [options.prepareInput] - Function to prepare input data before sending it to the API.
4852
*/
4953
constructor(options = {}) {
5054
this.auth = options.auth || process.env.REPLICATE_API_TOKEN;
5155
this.userAgent =
5256
options.userAgent || `replicate-javascript/${packageJSON.version}`;
5357
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
5458
this.fetch = options.fetch || globalThis.fetch;
59+
this.prepareInputs =
60+
options.prepareInputs ||
61+
(async (inputs) => {
62+
try {
63+
return await transformFileInputsToReplicateFileURLs(this, inputs);
64+
} catch (error) {
65+
return await transformFileInputsToBase64EncodedDataURIs(inputs);
66+
}
67+
});
5568

5669
this.accounts = {
5770
current: accounts.current.bind(this),
@@ -69,6 +82,13 @@ class Replicate {
6982
},
7083
};
7184

85+
this.files = {
86+
create: files.create.bind(this),
87+
list: files.list.bind(this),
88+
get: files.get.bind(this),
89+
delete: files.delete.bind(this),
90+
};
91+
7292
this.hardware = {
7393
list: hardware.list.bind(this),
7494
};
@@ -222,10 +242,17 @@ class Replicate {
222242
}
223243
}
224244

245+
let body = undefined;
246+
if (data instanceof FormData) {
247+
body = data;
248+
} else if (data) {
249+
body = JSON.stringify(data);
250+
}
251+
225252
const init = {
226253
method,
227254
headers,
228-
body: data ? JSON.stringify(data) : undefined,
255+
body,
229256
};
230257

231258
const shouldRetry =

index.test.ts

+38-2
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ describe("Replicate client", () => {
222222
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
223223
});
224224

225-
test.each([
225+
const fileTestCases = [
226226
// Skip test case if File type is not available
227227
...(typeof File !== "undefined"
228228
? [
@@ -245,11 +245,47 @@ describe("Replicate client", () => {
245245
value: Buffer.from("hello world"),
246246
expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=",
247247
},
248-
])(
248+
];
249+
250+
test.each(fileTestCases)(
251+
"converts a $type input into a Replicate file URL",
252+
async ({ value: data, expected }) => {
253+
nock(BASE_URL)
254+
.post("/files")
255+
.reply(201, {
256+
urls: {
257+
get: "https://replicate.com/api/files/123",
258+
},
259+
})
260+
.post("/predictions")
261+
.reply(201, (uri: string, body: Record<string, any>) => {
262+
return body;
263+
});
264+
265+
const prediction = await client.predictions.create({
266+
version:
267+
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
268+
input: {
269+
prompt: "Tell me a story",
270+
data,
271+
},
272+
stream: true,
273+
});
274+
275+
expect(prediction.input).toEqual({
276+
prompt: "Tell me a story",
277+
data: "https://replicate.com/api/files/123",
278+
});
279+
}
280+
);
281+
282+
test.each(fileTestCases)(
249283
"converts a $type input into a base64 encoded string",
250284
async ({ value: data, expected }) => {
251285
let actual: Record<string, any> | undefined;
252286
nock(BASE_URL)
287+
.post("/files")
288+
.reply(503, "Service Unavailable")
253289
.post("/predictions")
254290
.reply(201, (uri: string, body: Record<string, any>) => {
255291
actual = body;

lib/deployments.js

+1-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const { transformFileInputs } = require("./util");
2-
31
/**
42
* Create a new prediction with a deployment
53
*
@@ -30,7 +28,7 @@ async function createPrediction(deployment_owner, deployment_name, options) {
3028
method: "POST",
3129
data: {
3230
...data,
33-
input: await transformFileInputs(input),
31+
input: await this.prepareInputs(input),
3432
stream,
3533
},
3634
}

lib/files.js

+86
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
/**
2+
* Create a file
3+
*
4+
* @param {object} file - Required. The file object.
5+
* @param {object} metadata - Optional. User-provided metadata associated with the file.
6+
* @returns {Promise<object>} - Resolves with the file data
7+
*/
8+
async function createFile(file, metadata = {}) {
9+
const form = new FormData();
10+
11+
let filename;
12+
let blob;
13+
if (file instanceof Blob) {
14+
fileName = file.name || `blob_${Date.now()}`;
15+
blob = file;
16+
} else if (Buffer.isBuffer(file)) {
17+
fileName = `buffer_${Date.now()}`;
18+
blob = new Blob(file, { type: "application/octet-stream" });
19+
} else {
20+
throw new Error("Invalid file argument, must be a Blob or File");
21+
}
22+
23+
form.append("content", blob, filename);
24+
form.append(
25+
"metadata",
26+
new Blob([JSON.stringify(metadata)], { type: "application/json" })
27+
);
28+
29+
const response = await this.request("/files", {
30+
method: "POST",
31+
data: form,
32+
headers: {
33+
"Content-Type": "multipart/form-data",
34+
},
35+
});
36+
37+
return response.json();
38+
}
39+
40+
/**
41+
* List all files
42+
*
43+
* @returns {Promise<object>} - Resolves with the files data
44+
*/
45+
async function listFiles() {
46+
const response = await this.request("/files", {
47+
method: "GET",
48+
});
49+
50+
return response.json();
51+
}
52+
53+
/**
54+
* Get a file
55+
*
56+
* @param {string} file_id - Required. The ID of the file.
57+
* @returns {Promise<object>} - Resolves with the file data
58+
*/
59+
async function getFile(file_id) {
60+
const response = await this.request(`/files/${file_id}`, {
61+
method: "GET",
62+
});
63+
64+
return response.json();
65+
}
66+
67+
/**
68+
* Delete a file
69+
*
70+
* @param {string} file_id - Required. The ID of the file.
71+
* @returns {Promise<object>} - Resolves with the deletion confirmation
72+
*/
73+
async function deleteFile(file_id) {
74+
const response = await this.request(`/files/${file_id}`, {
75+
method: "DELETE",
76+
});
77+
78+
return response.json();
79+
}
80+
81+
module.exports = {
82+
create: createFile,
83+
list: listFiles,
84+
get: getFile,
85+
delete: deleteFile,
86+
};

lib/predictions.js

+2-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
const { transformFileInputs } = require("./util");
2-
31
/**
42
* Create a new prediction
53
*
@@ -30,7 +28,7 @@ async function createPrediction(options) {
3028
method: "POST",
3129
data: {
3230
...data,
33-
input: await transformFileInputs(input),
31+
input: await this.prepareInputs(input),
3432
version,
3533
stream,
3634
},
@@ -40,7 +38,7 @@ async function createPrediction(options) {
4038
method: "POST",
4139
data: {
4240
...data,
43-
input: await transformFileInputs(input),
41+
input: await this.prepareInputs(input),
4442
stream,
4543
},
4644
});

lib/util.js

+24-4
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,26 @@ async function withAutomaticRetries(request, options = {}) {
156156
return request();
157157
}
158158

159+
/**
160+
* Walks the inputs and, for any File or Blob, tries to upload it to Replicate
161+
* and replaces the input with the URL of the uploaded file.
162+
*
163+
* @param {Replicate} client - The client used to upload the file
164+
* @param {object} inputs - The inputs to transform
165+
* @returns {object} - The transformed inputs
166+
* @throws {ApiError} If the request to upload the file fails
167+
*/
168+
async function transformFileInputsToReplicateFileURLs(client, inputs) {
169+
return await transform(inputs, async (value) => {
170+
if (value instanceof Blob || value instanceof Buffer) {
171+
const file = await client.files.create(value);
172+
return file.urls.get;
173+
}
174+
175+
return value;
176+
});
177+
}
178+
159179
const MAX_DATA_URI_SIZE = 10_000_000;
160180

161181
/**
@@ -166,9 +186,9 @@ const MAX_DATA_URI_SIZE = 10_000_000;
166186
* @returns {object} - The transformed inputs
167187
* @throws {Error} If the size of inputs exceeds a given threshould set by MAX_DATA_URI_SIZE
168188
*/
169-
async function transformFileInputs(inputs) {
189+
async function transformFileInputsToBase64EncodedDataURIs(inputs) {
170190
let totalBytes = 0;
171-
const result = await transform(inputs, async (value) => {
191+
return await transform(inputs, async (value) => {
172192
let buffer;
173193
let mime;
174194

@@ -199,8 +219,6 @@ async function transformFileInputs(inputs) {
199219

200220
return `data:${mime};base64,${data}`;
201221
});
202-
203-
return result;
204222
}
205223

206224
// Walk a JavaScript object and transform the leaf values.
@@ -297,4 +315,6 @@ module.exports = {
297315
validateWebhook,
298316
withAutomaticRetries,
299317
parseProgressFromLogs,
318+
transformFileInputsToBase64EncodedDataURIs,
319+
transformFileInputsToReplicateFileURLs,
300320
};

0 commit comments

Comments
 (0)