Skip to content

Commit 839e476

Browse files
aronmattt
andauthored
Automatically transform binary inputs into data URIs (#198)
Co-authored-by: Mattt Zmuda <[email protected]>
1 parent d09067c commit 839e476

File tree

5 files changed

+163
-17
lines changed

5 files changed

+163
-17
lines changed

README.md

+2-11
Original file line numberDiff line numberDiff line change
@@ -73,26 +73,17 @@ console.log(prediction.output);
7373
// ['https://replicate.delivery/pbxt/RoaxeXqhL0xaYyLm6w3bpGwF5RaNBjADukfFnMbhOyeoWBdhA/out-0.png']
7474
```
7575

76-
To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can convert file data into a base64-encoded data URI and pass that directly:
76+
To run a model that takes a file input, pass a URL to a publicly accessible file. Or, for smaller files (<10MB), you can pass the data directly.
7777

7878
```js
7979
const fs = require("node:fs/promises");
8080

8181
// Or when using ESM.
8282
// import fs from "node:fs/promises";
8383

84-
// Read the file into a buffer
85-
const data = await fs.readFile("path/to/image.png");
86-
// Convert the buffer into a base64-encoded string
87-
const base64 = data.toString("base64");
88-
// Set MIME type for PNG image
89-
const mimeType = "image/png";
90-
// Create the data URI
91-
const dataURI = `data:${mimeType};base64,${base64}`;
92-
9384
const model = "nightmareai/real-esrgan:42fed1c4974146d4d2414e2be2c5277c7fcf05fcc3a73abf41610695738c1d7b";
9485
const input = {
95-
image: dataURI,
86+
image: await fs.readFile("path/to/image.png"),
9687
};
9788
const output = await replicate.run(model, { input });
9889
// ['https://replicate.delivery/mgxm/e7b0e122-9daa-410e-8cde-006c7308ff4d/output.png']

index.test.ts

+48
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,54 @@ describe("Replicate client", () => {
221221
expect(prediction.id).toBe("ufawqhfynnddngldkgtslldrkq");
222222
});
223223

224+
test.each([
225+
// Skip test case if File type is not available
226+
...(typeof File !== "undefined"
227+
? [
228+
{
229+
type: "file",
230+
value: new File(["hello world"], "hello.txt", {
231+
type: "text/plain",
232+
}),
233+
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
234+
},
235+
]
236+
: []),
237+
{
238+
type: "blob",
239+
value: new Blob(["hello world"], { type: "text/plain" }),
240+
expected: "data:text/plain;base64,aGVsbG8gd29ybGQ=",
241+
},
242+
{
243+
type: "buffer",
244+
value: Buffer.from("hello world"),
245+
expected: "data:application/octet-stream;base64,aGVsbG8gd29ybGQ=",
246+
},
247+
])(
248+
"converts a $type input into a base64 encoded string",
249+
async ({ value: data, expected }) => {
250+
let actual: Record<string, any> | undefined;
251+
nock(BASE_URL)
252+
.post("/predictions")
253+
.reply(201, (uri: string, body: Record<string, any>) => {
254+
actual = body;
255+
return body;
256+
});
257+
258+
await client.predictions.create({
259+
version:
260+
"5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
261+
input: {
262+
prompt: "Tell me a story",
263+
data,
264+
},
265+
stream: true,
266+
});
267+
268+
expect(actual?.input.data).toEqual(expected);
269+
}
270+
);
271+
224272
test("Passes stream parameter to API endpoint", async () => {
225273
nock(BASE_URL)
226274
.post("/predictions")

lib/deployments.js

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const { transformFileInputs } = require("./util");
2+
13
/**
24
* Create a new prediction with a deployment
35
*
@@ -11,7 +13,7 @@
1113
* @returns {Promise<object>} Resolves with the created prediction data
1214
*/
1315
async function createPrediction(deployment_owner, deployment_name, options) {
14-
const { stream, ...data } = options;
16+
const { stream, input, ...data } = options;
1517

1618
if (data.webhook) {
1719
try {
@@ -26,7 +28,11 @@ async function createPrediction(deployment_owner, deployment_name, options) {
2628
`/deployments/${deployment_owner}/${deployment_name}/predictions`,
2729
{
2830
method: "POST",
29-
data: { ...data, stream },
31+
data: {
32+
...data,
33+
input: await transformFileInputs(input),
34+
stream,
35+
},
3036
}
3137
);
3238

lib/predictions.js

+14-3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
const { transformFileInputs } = require("./util");
2+
13
/**
24
* Create a new prediction
35
*
@@ -11,7 +13,7 @@
1113
* @returns {Promise<object>} Resolves with the created prediction
1214
*/
1315
async function createPrediction(options) {
14-
const { model, version, stream, ...data } = options;
16+
const { model, version, stream, input, ...data } = options;
1517

1618
if (data.webhook) {
1719
try {
@@ -26,12 +28,21 @@ async function createPrediction(options) {
2628
if (version) {
2729
response = await this.request("/predictions", {
2830
method: "POST",
29-
data: { ...data, stream, version },
31+
data: {
32+
...data,
33+
input: await transformFileInputs(input),
34+
version,
35+
stream,
36+
},
3037
});
3138
} else if (model) {
3239
response = await this.request(`/models/${model}/predictions`, {
3340
method: "POST",
34-
data: { ...data, stream },
41+
data: {
42+
...data,
43+
input: await transformFileInputs(input),
44+
stream,
45+
},
3546
});
3647
} else {
3748
throw new Error("Either model or version must be specified");

lib/util.js

+91-1
Original file line numberDiff line numberDiff line change
@@ -156,4 +156,94 @@ async function withAutomaticRetries(request, options = {}) {
156156
return request();
157157
}
158158

159-
module.exports = { validateWebhook, withAutomaticRetries };
159+
const MAX_DATA_URI_SIZE = 10_000_000;
160+
161+
/**
 * Walks the inputs and transforms any binary data found into a
 * base64-encoded data URI.
 *
 * Blob values (File is a Blob subclass) and Node.js Buffers are encoded;
 * every other leaf value is returned unchanged.
 *
 * @param {object} inputs - The inputs to transform
 * @returns {Promise<object>} - Resolves with the transformed inputs
 * @throws {Error} If the combined size of the binary inputs exceeds the threshold set by MAX_DATA_URI_SIZE
 */
async function transformFileInputs(inputs) {
  // Running total across every binary leaf, so the limit applies to the
  // whole prediction rather than to each file individually.
  let totalBytes = 0;

  const result = await transform(inputs, async (value) => {
    let buffer;
    let mime;

    // Guard the global lookup: referencing an undefined `Blob` would throw a
    // ReferenceError on runtimes that do not expose it globally.
    if (typeof Blob !== "undefined" && value instanceof Blob) {
      // Currently we use a NodeJS only API for base64 encoding, as
      // we move to support the browser we could support either using
      // btoa (which does string encoding), the FileReader API or
      // a JavaScript implementation like base64-js.
      // See: https://developer.mozilla.org/en-US/docs/Glossary/Base64
      // See: https://github.com/beatgammit/base64-js
      buffer = Buffer.from(await value.arrayBuffer());
      mime = value.type;
    } else if (Buffer.isBuffer(value)) {
      buffer = value;
    } else {
      return value;
    }

    totalBytes += buffer.byteLength;
    if (totalBytes > MAX_DATA_URI_SIZE) {
      throw new Error(
        `Combined filesize of prediction ${totalBytes} bytes exceeds 10mb limit for inline encoding, please provide URLs instead`
      );
    }

    const data = buffer.toString("base64");
    // `Blob.type` is an empty string (not null/undefined) when no type was
    // given, so `||` is needed here; `??` would emit `data:;base64,...`.
    mime = mime || "application/octet-stream";

    return `data:${mime};base64,${data}`;
  });

  return result;
}
205+
206+
// Walk a JavaScript object and transform the leaf values.
207+
async function transform(value, mapper) {
208+
if (Array.isArray(value)) {
209+
let copy = [];
210+
for (const val of value) {
211+
copy = await transform(val, mapper);
212+
}
213+
return copy;
214+
}
215+
216+
if (isPlainObject(value)) {
217+
const copy = {};
218+
for (const key of Object.keys(value)) {
219+
copy[key] = await transform(value[key], mapper);
220+
}
221+
return copy;
222+
}
223+
224+
return await mapper(value);
225+
}
226+
227+
// Test for a plain JS object.
228+
// Source: lodash.isPlainObject
229+
function isPlainObject(value) {
230+
const isObjectLike = typeof value === "object" && value !== null;
231+
if (!isObjectLike || String(value) !== "[object Object]") {
232+
return false;
233+
}
234+
const proto = Object.getPrototypeOf(value);
235+
if (proto === null) {
236+
return true;
237+
}
238+
const Ctor =
239+
Object.prototype.hasOwnProperty.call(proto, "constructor") &&
240+
proto.constructor;
241+
return (
242+
typeof Ctor === "function" &&
243+
Ctor instanceof Ctor &&
244+
Function.prototype.toString.call(Ctor) ===
245+
Function.prototype.toString.call(Object)
246+
);
247+
}
248+
249+
module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries };

0 commit comments

Comments
 (0)