Skip to content

Commit 60a8e18

Browse files
authored
Add parseProgress helper function (#207)
* Add parsePredictionProgress helper function * Rename parsePredictionProgress to parseProgress * Annotate possible null return value in jsdoc * Annotate possible null return value in type definition * Rename parseProgress to parseProgressFromLogs * Expand documentation of parseProgressFromLogs
1 parent a0b8f96 commit 60a8e18

File tree

4 files changed

+150
-26
lines changed

4 files changed

+150
-26
lines changed

index.d.ts

+6
Original file line numberDiff line numberDiff line change
@@ -280,4 +280,10 @@ declare module "replicate" {
280280
},
281281
secret: string
282282
): boolean;
283+
284+
export function parseProgressFromLogs(logs: Prediction | string): {
285+
percentage: number;
286+
current: number;
287+
total: number;
288+
} | null;
283289
}

index.js

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
33
const { Stream } = require("./lib/stream");
4-
const { withAutomaticRetries, validateWebhook } = require("./lib/util");
4+
const {
5+
withAutomaticRetries,
6+
validateWebhook,
7+
parseProgressFromLogs,
8+
} = require("./lib/util");
59

610
const accounts = require("./lib/accounts");
711
const collections = require("./lib/collections");
@@ -375,3 +379,4 @@ class Replicate {
375379

376380
module.exports = Replicate;
377381
module.exports.validateWebhook = validateWebhook;
382+
module.exports.parseProgressFromLogs = parseProgressFromLogs;

index.test.ts

+86-24
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import Replicate, {
44
Model,
55
Prediction,
66
validateWebhook,
7+
parseProgressFromLogs,
78
} from "replicate";
89
import nock from "nock";
910
import fetch from "cross-fetch";
@@ -888,63 +889,124 @@ describe("Replicate client", () => {
888889
});
889890

890891
describe("run", () => {
891-
test("Calls the correct API routes for a version", async () => {
892-
const firstPollingRequest = true;
893-
892+
test("Calls the correct API routes", async () => {
894893
nock(BASE_URL)
895894
.post("/predictions")
896895
.reply(201, {
897896
id: "ufawqhfynnddngldkgtslldrkq",
898897
status: "starting",
898+
logs: null,
899899
})
900900
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
901-
.twice()
902901
.reply(200, {
903902
id: "ufawqhfynnddngldkgtslldrkq",
904903
status: "processing",
904+
logs: [
905+
"Using seed: 12345",
906+
"0%| | 0/5 [00:00<?, ?it/s]",
907+
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
908+
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
909+
].join("\n"),
910+
})
911+
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
912+
.reply(200, {
913+
id: "ufawqhfynnddngldkgtslldrkq",
914+
status: "processing",
915+
logs: [
916+
"Using seed: 12345",
917+
"0%| | 0/5 [00:00<?, ?it/s]",
918+
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
919+
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
920+
"60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]",
921+
"80%|████████ | 4/5 [00:01<00:00, 22.86it/s]",
922+
].join("\n"),
905923
})
906924
.get("/predictions/ufawqhfynnddngldkgtslldrkq")
907925
.reply(200, {
908926
id: "ufawqhfynnddngldkgtslldrkq",
909927
status: "succeeded",
910928
output: "Goodbye!",
929+
logs: [
930+
"Using seed: 12345",
931+
"0%| | 0/5 [00:00<?, ?it/s]",
932+
"20%|██ | 1/5 [00:00<00:01, 21.38it/s]",
933+
"40%|████▍ | 2/5 [00:01<00:01, 22.46it/s]",
934+
"60%|████▍ | 3/5 [00:01<00:01, 22.46it/s]",
935+
"80%|████████ | 4/5 [00:01<00:00, 22.86it/s]",
936+
"100%|██████████| 5/5 [00:02<00:00, 22.26it/s]",
937+
].join("\n"),
911938
});
912939

913-
const progress = jest.fn();
940+
const callback = jest.fn();
914941

915942
const output = await client.run(
916943
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
917944
{
918945
input: { text: "Hello, world!" },
919946
wait: { interval: 1 },
920947
},
921-
progress
948+
(prediction) => {
949+
const progress = parseProgressFromLogs(prediction);
950+
callback(prediction, progress);
951+
}
922952
);
923953

924954
expect(output).toBe("Goodbye!");
925955

926-
expect(progress).toHaveBeenNthCalledWith(1, {
927-
id: "ufawqhfynnddngldkgtslldrkq",
928-
status: "starting",
929-
});
956+
expect(callback).toHaveBeenNthCalledWith(
957+
1,
958+
{
959+
id: "ufawqhfynnddngldkgtslldrkq",
960+
status: "starting",
961+
logs: null,
962+
},
963+
null
964+
);
930965

931-
expect(progress).toHaveBeenNthCalledWith(2, {
932-
id: "ufawqhfynnddngldkgtslldrkq",
933-
status: "processing",
934-
});
966+
expect(callback).toHaveBeenNthCalledWith(
967+
2,
968+
{
969+
id: "ufawqhfynnddngldkgtslldrkq",
970+
status: "processing",
971+
logs: expect.any(String),
972+
},
973+
{
974+
percentage: 0.4,
975+
current: 2,
976+
total: 5,
977+
}
978+
);
935979

936-
expect(progress).toHaveBeenNthCalledWith(3, {
937-
id: "ufawqhfynnddngldkgtslldrkq",
938-
status: "processing",
939-
});
980+
expect(callback).toHaveBeenNthCalledWith(
981+
3,
982+
{
983+
id: "ufawqhfynnddngldkgtslldrkq",
984+
status: "processing",
985+
logs: expect.any(String),
986+
},
987+
{
988+
percentage: 0.8,
989+
current: 4,
990+
total: 5,
991+
}
992+
);
940993

941-
expect(progress).toHaveBeenNthCalledWith(4, {
942-
id: "ufawqhfynnddngldkgtslldrkq",
943-
status: "succeeded",
944-
output: "Goodbye!",
945-
});
994+
expect(callback).toHaveBeenNthCalledWith(
995+
4,
996+
{
997+
id: "ufawqhfynnddngldkgtslldrkq",
998+
status: "succeeded",
999+
logs: expect.any(String),
1000+
output: "Goodbye!",
1001+
},
1002+
{
1003+
percentage: 1.0,
1004+
current: 5,
1005+
total: 5,
1006+
}
1007+
);
9461008

947-
expect(progress).toHaveBeenCalledTimes(4);
1009+
expect(callback).toHaveBeenCalledTimes(4);
9481010
});
9491011

9501012
test("Calls the correct API routes for a model", async () => {

lib/util.js

+52-1
Original file line numberDiff line numberDiff line change
@@ -246,4 +246,55 @@ function isPlainObject(value) {
246246
);
247247
}
248248

249-
module.exports = { transformFileInputs, validateWebhook, withAutomaticRetries };
249+
/**
250+
* Parse progress from prediction logs.
251+
*
252+
* This function supports log statements in the following format,
253+
* which are generated by https://github.com/tqdm/tqdm and similar libraries:
254+
*
255+
* ```
256+
* 76%|████████████████████████████ | 7568/10000 [00:33<00:10, 229.00it/s]
257+
* ```
258+
*
259+
* @example
260+
* const progress = parseProgressFromLogs("76%|████████████████████████████ | 7568/10000 [00:33<00:10, 229.00it/s]");
261+
* console.log(progress);
262+
* // {
263+
* // percentage: 0.76,
264+
* // current: 7568,
265+
* // total: 10000,
266+
* // }
267+
*
268+
* @param {object|string} input - A prediction object or string.
269+
* @returns {(object|null)} - An object with the percentage, current, and total, or null if no progress can be parsed.
270+
*/
271+
function parseProgressFromLogs(input) {
272+
const logs = typeof input === "object" && input.logs ? input.logs : input;
273+
if (!logs || typeof logs !== "string") {
274+
return null;
275+
}
276+
277+
const pattern = /^\s*(\d+)%\s*\|.+?\|\s*(\d+)\/(\d+)/;
278+
const lines = logs.split("\n").reverse();
279+
280+
for (const line of lines) {
281+
const matches = line.match(pattern);
282+
283+
if (matches && matches.length === 4) {
284+
return {
285+
percentage: parseInt(matches[1], 10) / 100,
286+
current: parseInt(matches[2], 10),
287+
total: parseInt(matches[3], 10),
288+
};
289+
}
290+
}
291+
292+
return null;
293+
}
294+
295+
module.exports = {
296+
transformFileInputs,
297+
validateWebhook,
298+
withAutomaticRetries,
299+
parseProgressFromLogs,
300+
};

0 commit comments

Comments
 (0)