Skip to content

Commit 9c54b7e

Browse files
committed
Call the onProgress handler with the canceled prediction
Previously when aborting a `run()` request we were dropping the final canceled prediction object and calling the `onProgress` callback with a stale "processing" object.
1 parent 9f1a89c commit 9c54b7e

File tree

2 files changed

+29
-4
lines changed

2 files changed

+29
-4
lines changed

index.js

+6-2
Original file line numberDiff line numberDiff line change
@@ -162,15 +162,19 @@ class Replicate {
162162
progress(updatedPrediction);
163163
}
164164

165-
if (signal && signal.aborted) {
166-
await this.predictions.cancel(updatedPrediction.id);
165+
// We handle the cancel later in the function.
166+
if (signal?.aborted) {
167167
return true; // stop polling
168168
}
169169

170170
return false; // continue polling
171171
}
172172
);
173173

174+
if (signal?.aborted) {
175+
prediction = await this.predictions.cancel(prediction.id);
176+
}
177+
174178
// Call progress callback with the completed prediction object
175179
if (progress) {
176180
progress(prediction);

index.test.ts

+23-2
Original file line numberDiff line numberDiff line change
@@ -1258,19 +1258,40 @@ describe("Replicate client", () => {
12581258
status: "canceled",
12591259
});
12601260

1261-
await client.run(
1261+
const onProgress = jest.fn();
1262+
const output = await client.run(
12621263
"owner/model:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa",
12631264
{
12641265
input: { text: "Hello, world!" },
12651266
signal,
1266-
}
1267+
},
1268+
onProgress
12671269
);
12681270

12691271
expect(body).toBeDefined();
12701272
expect(body?.["signal"]).toBeUndefined();
12711273
expect(signal.aborted).toBe(true);
12721274
expect(output).toBeUndefined();
12731275

1276+
expect(onProgress).toHaveBeenNthCalledWith(
1277+
1,
1278+
expect.objectContaining({
1279+
status: "processing",
1280+
})
1281+
);
1282+
expect(onProgress).toHaveBeenNthCalledWith(
1283+
2,
1284+
expect.objectContaining({
1285+
status: "processing",
1286+
})
1287+
);
1288+
expect(onProgress).toHaveBeenNthCalledWith(
1289+
3,
1290+
expect.objectContaining({
1291+
status: "canceled",
1292+
})
1293+
);
1294+
12741295
scope.done();
12751296
});
12761297
});

0 commit comments

Comments
 (0)