-
Notifications
You must be signed in to change notification settings - Fork 219
/
Copy pathpredictions.js
106 lines (95 loc) · 2.89 KB
/
predictions.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
const { transformFileInputs } = require("./util");
/**
 * Create a new prediction.
 *
 * Must be invoked with `this` bound to a client object that exposes
 * `request(path, options)` and `fileEncodingStrategy`.
 *
 * If both `version` and `model` are given, `version` takes precedence and the
 * prediction is created via `POST /predictions`. Otherwise, `model` selects
 * `POST /models/{model}/predictions`.
 *
 * @param {object} options
 * @param {string} [options.model] - The model, interpolated into the
 *   `/models/{model}/predictions` endpoint. Used only when `version` is not given.
 * @param {string} [options.version] - The model version. Takes precedence over `model`.
 * @param {object} options.input - Required. An object with the model inputs.
 *   File-like values are converted by `transformFileInputs` using the client's
 *   `fileEncodingStrategy`.
 * @param {string} [options.webhook] - A URL for receiving a webhook when the
 *   prediction has new output. Only checked to be parseable by `new URL`;
 *   the scheme is not enforced here.
 * @param {string[]} [options.webhook_events_filter] - Which events trigger
 *   webhook requests (`start`|`output`|`logs`|`completed`).
 * @param {boolean} [options.stream] - Deprecated. Streaming is now always
 *   available for all predictions; if provided, this flag is still forwarded
 *   to the API unchanged. See
 *   https://replicate.com/changelog/2024-07-15-streams-always-available-stream-parameter-deprecated
 * @returns {Promise<object>} Resolves with the parsed JSON body of the created prediction.
 * @throws {Error} "Invalid webhook URL" if `webhook` cannot be parsed (the parse
 *   error is attached as `cause`).
 * @throws {Error} "Either model or version must be specified" if neither is given.
 */
async function createPrediction(options) {
  const { model, version, input, ...data } = options;

  if (data.webhook) {
    try {
      // eslint-disable-next-line no-new
      new URL(data.webhook);
    } catch (err) {
      // Preserve the underlying parse error for debugging.
      throw new Error("Invalid webhook URL", { cause: err });
    }
  }

  // Resolve the endpoint before transforming inputs, so that a missing
  // model/version fails fast without encoding or uploading any files.
  let path;
  if (version) {
    path = "/predictions";
  } else if (model) {
    path = `/models/${model}/predictions`;
  } else {
    throw new Error("Either model or version must be specified");
  }

  const payload = {
    ...data,
    input: await transformFileInputs(this, input, this.fileEncodingStrategy),
  };
  // The version-based endpoint needs the version in the body; the model-based
  // endpoint carries the model in the path instead.
  if (version) {
    payload.version = version;
  }

  const response = await this.request(path, {
    method: "POST",
    data: payload,
  });
  return response.json();
}
/**
 * Fetch a prediction by ID.
 *
 * Must be invoked with `this` bound to a client object that exposes
 * `request(path, options)`.
 *
 * @param {string} prediction_id - Required. The prediction ID, interpolated
 *   into the `/predictions/{prediction_id}` path.
 * @returns {Promise<object>} Resolves with the parsed JSON body of the prediction.
 */
async function getPrediction(prediction_id) {
  const response = await this.request(`/predictions/${prediction_id}`, {
    method: "GET",
  });
  return response.json();
}
/**
 * Cancel a prediction by ID.
 *
 * Must be invoked with `this` bound to a client object that exposes
 * `request(path, options)`. Sends `POST /predictions/{prediction_id}/cancel`.
 *
 * @param {string} prediction_id - Required. The prediction ID.
 * @returns {Promise<object>} Resolves with the parsed JSON body returned for the
 *   canceled prediction.
 */
async function cancelPrediction(prediction_id) {
  const response = await this.request(`/predictions/${prediction_id}/cancel`, {
    method: "POST",
  });
  return response.json();
}
/**
 * List predictions (a single page, as returned by the API).
 *
 * Must be invoked with `this` bound to a client object that exposes
 * `request(path, options)`. Issues `GET /predictions` with no query parameters.
 *
 * @returns {Promise<object>} - Resolves with the parsed JSON page of predictions
 */
async function listPredictions() {
  const endpoint = "/predictions";
  const requestOptions = { method: "GET" };
  const res = await this.request(endpoint, requestOptions);
  return res.json();
}
// Public predictions API. Each function expects to be called with `this`
// bound to the client (it uses `this.request`).
const predictionsApi = {
  create: createPrediction,
  get: getPrediction,
  cancel: cancelPrediction,
  list: listPredictions,
};

module.exports = predictionsApi;