Skip to content

Commit 6a33a89

Browse files
committed
Tighten up the typings across the board
1 parent 8d080f8 commit 6a33a89

12 files changed

+100
-36
lines changed

index.js

+17-14
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@ require("./lib/types");
77
* @deprecated use exported Replicate class instead
88
*/
99
class DeprecatedReplicate extends ReplicateClass {
10-
/** @deprecated Use `import { Replicate } from "replicate";` instead */
10+
/**
11+
* @deprecated Use `import { Replicate } from "replicate";` instead
12+
* @param {ConstructorParameters<typeof ReplicateClass>[0]=} options
13+
*/
1114
// biome-ignore lint/complexity/noUselessConstructor: exists for the tsdoc comment
12-
constructor(...args) {
13-
super(...args);
15+
constructor(options) {
16+
super(options);
1417
}
1518
}
1619

@@ -77,20 +80,20 @@ module.exports = replicate;
7780
/**
7881
* @typedef {import("./lib/replicate")} Replicate
7982
* @typedef {import("./lib/error")} ApiError
80-
* @typedef {typeof import("./lib/types").Collection} Collection
81-
* @typedef {typeof import("./lib/types").ModelVersion} ModelVersion
82-
* @typedef {typeof import("./lib/types").Hardware} Hardware
83-
* @typedef {typeof import("./lib/types").Model} Model
84-
* @typedef {typeof import("./lib/types").Prediction} Prediction
85-
* @typedef {typeof import("./lib/types").Training} Training
86-
* @typedef {typeof import("./lib/types").ServerSentEvent} ServerSentEvent
87-
* @typedef {typeof import("./lib/types").Status} Status
88-
* @typedef {typeof import("./lib/types").Visibility} Visibility
89-
* @typedef {typeof import("./lib/types").WebhookEventType} WebhookEventType
83+
* @typedef {import("./lib/types").Collection} Collection
84+
* @typedef {import("./lib/types").ModelVersion} ModelVersion
85+
* @typedef {import("./lib/types").Hardware} Hardware
86+
* @typedef {import("./lib/types").Model} Model
87+
* @typedef {import("./lib/types").Prediction} Prediction
88+
* @typedef {import("./lib/types").Training} Training
89+
* @typedef {import("./lib/types").ServerSentEvent} ServerSentEvent
90+
* @typedef {import("./lib/types").Status} Status
91+
* @typedef {import("./lib/types").Visibility} Visibility
92+
* @typedef {import("./lib/types").WebhookEventType} WebhookEventType
9093
*/
9194

9295
/**
9396
* @template T
94-
* @typedef {typeof import("./lib/types").Page} Page
97+
* @typedef {import("./lib/types").Page<T>} Page
9598
*/
9699

integration/typescript/types.test.ts

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { ApiError, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate";
1+
import { ApiError, Collection, Hardware, Model, ModelVersion, Page, Prediction, Status, Training, Visibility, WebhookEventType } from "replicate";
22

33
// NOTE: We export the constants to avoid unused varaible issues.
44

lib/collections.js

+6
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
/** @typedef {import("./types").Collection} Collection */
2+
/**
3+
* @template T
4+
* @typedef {import("./types").Page<T>} Page
5+
*/
6+
17
/**
28
* Fetch a model collection
39
*

lib/deployments.js

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
/** @typedef {import("./types").Prediction} Prediction */
2+
13
/**
24
* Create a new prediction with a deployment
35
*
46
* @param {string} deployment_owner - Required. The username of the user or organization who owns the deployment
57
* @param {string} deployment_name - Required. The name of the deployment
68
* @param {object} options
7-
* @param {object} options.input - Required. An object with the model inputs
9+
* @param {unknown} options.input - Required. An object with the model inputs
810
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false
911
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
1012
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)

lib/hardware.js

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
/** @typedef {import("./types").Hardware} Hardware */
12
/**
23
* List hardware
34
*

lib/identifier.js

+6-6
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,21 @@
22
* A reference to a model version in the format `owner/name` or `owner/name:version`.
33
*/
44
class ModelVersionIdentifier {
5-
/*
6-
* @param {string} Required. The model owner.
7-
* @param {string} Required. The model name.
8-
* @param {string} The model version.
5+
/**
6+
* @param {string} owner Required. The model owner.
7+
* @param {string} name Required. The model name.
8+
* @param {string | null=} version The model version.
99
*/
1010
constructor(owner, name, version = null) {
1111
this.owner = owner;
1212
this.name = name;
1313
this.version = version;
1414
}
1515

16-
/*
16+
/**
1717
* Parse a reference to a model version
1818
*
19-
* @param {string}
19+
* @param {string} ref
2020
* @returns {ModelVersionIdentifier}
2121
* @throws {Error} If the reference is invalid.
2222
*/

lib/models.js

+10-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,12 @@
1+
/** @typedef {import("./types").Model} Model */
2+
/** @typedef {import("./types").ModelVersion} ModelVersion */
3+
/** @typedef {import("./types").Prediction} Prediction */
4+
/** @typedef {import("./types").Visibility} Visibility */
5+
/**
6+
* @template T
7+
* @typedef {import("./types").Page<T>} Page
8+
*/
9+
110
/**
211
* Get information about a model
312
*
@@ -69,7 +78,7 @@ async function listModels() {
6978
* @param {string} model_owner - Required. The name of the user or organization that will own the model. This must be the same as the user or organization that is making the API request. In other words, the API token used in the request must belong to this user or organization.
7079
* @param {string} model_name - Required. The name of the model. This must be unique among all models owned by the user or organization.
7180
* @param {object} options
72-
* @param {("public"|"private")} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
81+
* @param {Visibility} options.visibility - Required. Whether the model should be public or private. A public model can be viewed and run by anyone, whereas a private model can be viewed and run only by the user or organization members that own the model.
7382
* @param {string} options.hardware - Required. The SKU for the hardware used to run the model. Possible values can be found by calling `Replicate.hardware.list()`.
7483
* @param {string} options.description - A description of the model.
7584
* @param {string=} options.github_url - A URL for the model's source code on GitHub.

lib/predictions.js

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
/**
2+
* @template T
3+
* @typedef {import("./types").Page<T>} Page
4+
*/
5+
/** @typedef {import("./types").Prediction} Prediction */
6+
17
/**
28
* Create a new prediction
39
*
410
* @param {object} options
511
* @param {string=} options.model - The model (for official models)
612
* @param {string=} options.version - The model version.
7-
* @param {object} options.input - Required. An object with the model inputs
13+
* @param {unknown} options.input - Required. An object with the model inputs
814
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
915
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
1016
* @param {boolean} [options.stream] - Whether to stream the prediction output. Defaults to false

lib/replicate.js

+30-7
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,11 @@
1+
/**
2+
* @template T
3+
* @typedef {import("./types").Page<T>} Page
4+
*/
5+
6+
/** @typedef {import("./types").Prediction} Prediction */
7+
/** @typedef {import("./types").WebhookEventType} WebhookEventType */
8+
19
const ApiError = require("./error");
210
const ModelVersionIdentifier = require("./identifier");
311
const { Stream } = require("./stream");
@@ -49,34 +57,45 @@ module.exports = class Replicate {
4957
* const input = {text: 'Hello, world!'}
5058
* const output = await replicate.run(model, { input });
5159
*
52-
* @param {Object={}} options - Configuration options for the client
60+
* @param {Object} [options] - Configuration options for the client
5361
* @param {string} [options.auth] - API access token. Defaults to the `REPLICATE_API_TOKEN` environment variable.
5462
* @param {string} [options.userAgent] - Identifier of your app
5563
* @param {string} [options.baseUrl] - Defaults to https://api.replicate.com/v1
5664
* @param {Function} [options.fetch] - Fetch function to use. Defaults to `globalThis.fetch`
5765
*/
5866
constructor(options = {}) {
67+
/** @type {string} */
5968
this.auth = options.auth || process.env.REPLICATE_API_TOKEN;
69+
70+
/** @type {string} */
6071
this.userAgent =
6172
options.userAgent || `replicate-javascript/${packageJSON.version}`;
73+
74+
/** @type {string} */
6275
this.baseUrl = options.baseUrl || "https://api.replicate.com/v1";
76+
77+
/** @type {fetch} */
6378
this.fetch = options.fetch || globalThis.fetch;
6479

80+
/** @type {collections} */
6581
this.collections = {
6682
list: collections.list.bind(this),
6783
get: collections.get.bind(this),
6884
};
6985

86+
/** @type {deployments} */
7087
this.deployments = {
7188
predictions: {
7289
create: deployments.predictions.create.bind(this),
7390
},
7491
};
7592

93+
/** @type {hardware} */
7694
this.hardware = {
7795
list: hardware.list.bind(this),
7896
};
7997

98+
/** @type {models} */
8099
this.models = {
81100
get: models.get.bind(this),
82101
list: models.list.bind(this),
@@ -87,13 +106,15 @@ module.exports = class Replicate {
87106
},
88107
};
89108

109+
/** @type {predictions} */
90110
this.predictions = {
91111
create: predictions.create.bind(this),
92112
get: predictions.get.bind(this),
93113
cancel: predictions.cancel.bind(this),
94114
list: predictions.list.bind(this),
95115
};
96116

117+
/** @type {trainings} */
97118
this.trainings = {
98119
create: trainings.create.bind(this),
99120
get: trainings.get.bind(this),
@@ -111,12 +132,12 @@ module.exports = class Replicate {
111132
* @param {object} [options.wait] - Options for waiting for the prediction to finish
112133
* @param {number} [options.wait.interval] - Polling interval in milliseconds. Defaults to 500
113134
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
114-
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
135+
* @param {WebhookEventType[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
115136
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
116137
* @param {Function} [progress] - Callback function that receives the prediction object as it's updated. The function is called when the prediction is created, each time its updated while polling for completion, and when it's completed.
117138
* @throws {Error} If the reference is invalid
118139
* @throws {Error} If the prediction failed
119-
* @returns {Promise<object>} - Resolves with the output of running the model
140+
* @returns {Promise<Prediction>} - Resolves with the output of running the model
120141
*/
121142
async run(ref, options, progress) {
122143
const { wait, ...data } = options;
@@ -252,7 +273,7 @@ module.exports = class Replicate {
252273
/**
253274
* Stream a model and wait for its output.
254275
*
255-
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
276+
* @param {string} ref - Required. The model version identifier in the format "{owner}/{name}:{version}"
256277
* @param {object} options
257278
* @param {object} options.input - Required. An object with the model inputs
258279
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
@@ -300,8 +321,10 @@ module.exports = class Replicate {
300321
* for await (const page of replicate.paginate(replicate.predictions.list) {
301322
* console.log(page);
302323
* }
303-
* @param {Function} endpoint - Function that returns a promise for the next page of results
304-
* @yields {object[]} Each page of results
324+
* @template T
325+
* @param {() => Promise<Page<T>>} endpoint - Function that returns a promise for the next page of results
326+
* @yields {T[]} Each page of results
327+
* @returns {AsyncGenerator<T[], void, unknown>}
305328
*/
306329
async *paginate(endpoint) {
307330
const response = await endpoint();
@@ -327,7 +350,7 @@ module.exports = class Replicate {
327350
* @param {Function} [stop] - Async callback function that is called after each polling attempt. Receives the prediction object as an argument. Return false to cancel polling.
328351
* @throws {Error} If the prediction doesn't complete within the maximum number of attempts
329352
* @throws {Error} If the prediction failed
330-
* @returns {Promise<object>} Resolves with the completed prediction object
353+
* @returns {Promise<Prediction>} Resolves with the completed prediction object
331354
*/
332355
async wait(prediction, options, stop) {
333356
const { id } = prediction;

lib/stream.js

+9-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// Attempt to use readable-stream if available, attempt to use the built-in stream module.
2+
/** @type {import("stream").Readable} */
23
let Readable;
34
try {
45
Readable = require("readable-stream").Readable;
@@ -49,7 +50,7 @@ class Stream extends Readable {
4950
* Create a new stream of server-sent events.
5051
*
5152
* @param {string} url The URL to connect to.
52-
* @param {object} options The fetch options.
53+
* @param {RequestInit=} options The fetch options.
5354
*/
5455
constructor(url, options) {
5556
if (!Readable) {
@@ -63,11 +64,18 @@ class Stream extends Readable {
6364
this.options = options;
6465

6566
this.event = null;
67+
68+
/** @type {unknown[]} */
6669
this.data = [];
70+
71+
/** @type {string | null} */
6772
this.lastEventId = null;
73+
74+
/** @type {number | null} */
6875
this.retry = null;
6976
}
7077

78+
/** @param {string=} line */
7179
decode(line) {
7280
if (!line) {
7381
if (!this.event && !this.data.length && !this.lastEventId) {

lib/trainings.js

+7-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
/**
2+
* @template T
3+
* @typedef {import("./types").Page<T>} Page
4+
*/
5+
/** @typedef {import("./types").Training} Training */
6+
17
/**
28
* Create a new training
39
*
@@ -6,7 +12,7 @@
612
* @param {string} version_id - Required. The version ID
713
* @param {object} options
814
* @param {string} options.destination - Required. The destination for the trained version in the form "{username}/{model_name}"
9-
* @param {object} options.input - Required. An object with the model inputs
15+
* @param {unknown} options.input - Required. An object with the model inputs
1016
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the training updates
1117
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
1218
* @returns {Promise<Training>} Resolves with the data for the created training

lib/types.js

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
* @typedef {"public" | "private"} Visibility
44
* @typedef {"start" | "output" | "logs" | "completed"} WebhookEventType
55
*
6-
* @typedef {import('./lib/error')} ApiError
7-
*
86
* @typedef {Object} Collection
97
* @property {string} name
108
* @property {string} slug
@@ -55,7 +53,7 @@
5553
*
5654
* @typedef {Prediction} Training
5755
*
58-
* @property {Object} ServerSentEvent
56+
* @typedef {Object} ServerSentEvent
5957
* @property {string} event
6058
* @property {string} data
6159
* @property {string=} id
@@ -69,3 +67,5 @@
6967
* @property {string=} next
7068
* @property {T[]} results
7169
*/
70+
71+
module.exports = {};

0 commit comments

Comments
 (0)