Skip to content

Commit cbde2c1

Browse files
authored
Add replicate.stream method (#169)
* Add replicate.stream method * Update README
1 parent 91c03f5 commit cbde2c1

File tree

4 files changed

+225
-0
lines changed

4 files changed

+225
-0
lines changed

README.md

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,49 @@ const input = { prompt: "a 19th century portrait of a raccoon gentleman wearing
172172
const output = await replicate.run(model, { input });
173173
```
174174

175+
### `replicate.stream`
176+
177+
Run a model and stream its output. Unlike [`replicate.predictions.create`](#replicatepredictionscreate), this method returns only the prediction output rather than the entire prediction object.
178+
179+
```js
180+
for await (const event of replicate.stream(identifier, options)) { /* ... */ }
181+
```
182+
183+
| name | type | description |
184+
| ------------------------------- | -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------- |
185+
| `identifier` | string | **Required**. The model version identifier in the format `{owner}/{name}` or `{owner}/{name}:{version}`, for example `meta/llama-2-70b-chat` |
186+
| `options.input` | object | **Required**. An object with the model inputs. |
187+
| `options.webhook` | string | An HTTPS URL for receiving a webhook when the prediction has new output |
188+
| `options.webhook_events_filter` | string[] | An array of events which should trigger [webhooks](https://replicate.com/docs/webhooks). Allowable values are `start`, `output`, `logs`, and `completed` |
189+
| `options.signal` | object | An [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal) to cancel the prediction |
190+
191+
Throws `Error` if the prediction failed.
192+
193+
Returns `AsyncGenerator<ServerSentEvent>` which yields the events of running the model.
194+
195+
Example:
196+
197+
```js
198+
for await (const event of replicate.stream("meta/llama-2-70b-chat", { input: { prompt: "Tell me a story" } })) {
199+
process.stdout.write(`${event}`);
200+
}
201+
```
202+
203+
### Server-sent events
204+
205+
A stream generates server-sent events with the following properties:
206+
207+
| name | type | description |
208+
| ------- | ------ | ---------------------------------------------------------------------------- |
209+
| `event` | string | The type of event. Possible values are `output`, `logs`, `error`, and `done` |
210+
| `data` | string | The event data |
211+
| `id` | string | The event id |
212+
| `retry` | number | The number of milliseconds to wait before reconnecting to the server |
213+
214+
As the prediction runs, the generator yields `output` and `logs` events. If an error occurs, the server sends an `error` event with a JSON object containing the error message set to the `data` property, and the generator throws an `Error` whose message is that `data`. When the prediction is done, the generator yields a `done` event with an empty JSON object set to the `data` property.
215+
216+
Events with the `output` event type have their `toString()` method overridden to return the event data as a string. Other event types return an empty string.
217+
175218
### `replicate.models.get`
176219

177220
Get metadata for a public model or a private model that you own.

index.d.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,13 @@ declare module "replicate" {
7575
results: T[];
7676
}
7777

78+
/**
 * A server-sent event yielded by `Replicate.stream` while a prediction runs.
 */
export interface ServerSentEvent {
79+
/** The type of event. Documented values are `output`, `logs`, `error`, and `done`. */
event: string;
80+
/** The event data. */
data: string;
81+
/** The event id, when the server supplied one. */
id?: string;
82+
/** The number of milliseconds to wait before reconnecting to the server, when supplied. */
retry?: number;
83+
}
84+
7885
export default class Replicate {
7986
constructor(options?: {
8087
auth?: string;
@@ -103,6 +110,16 @@ declare module "replicate" {
103110
progress?: (prediction: Prediction) => void
104111
): Promise<object>;
105112

113+
/**
 * Run a model and stream its output as server-sent events.
 *
 * @param identifier - The model identifier, either `{owner}/{name}` or `{owner}/{name}:{version}`.
 * @param options - Prediction options; `input` is required.
 * @returns An async generator that yields each event of the running prediction.
 */
stream(
114+
identifier: `${string}/${string}` | `${string}/${string}:${string}`,
115+
options: {
116+
// The model inputs.
input: object;
117+
// HTTPS URL that receives a webhook when the prediction has new output.
webhook?: string;
118+
// Which events should trigger webhook requests.
webhook_events_filter?: WebhookEventType[];
119+
// Signal used to cancel the prediction stream.
signal?: AbortSignal;
120+
}
121+
): AsyncGenerator<ServerSentEvent>;
122+
106123
request(
107124
route: string | URL,
108125
options: {

index.js

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
3+
const { Stream } = require("./lib/stream");
34
const { withAutomaticRetries } = require("./lib/util");
45

56
const collections = require("./lib/collections");
@@ -235,6 +236,47 @@ class Replicate {
235236
return response;
236237
}
237238

239+
/**
240+
* Stream a model and wait for its output.
241+
*
242+
* @param {string} identifier - Required. The model version identifier in the format "{owner}/{name}:{version}"
243+
* @param {object} options
244+
* @param {object} options.input - Required. An object with the model inputs
245+
* @param {string} [options.webhook] - An HTTPS URL for receiving a webhook when the prediction has new output
246+
* @param {string[]} [options.webhook_events_filter] - You can change which events trigger webhook requests by specifying webhook events (`start`|`output`|`logs`|`completed`)
247+
* @param {AbortSignal} [options.signal] - AbortSignal to cancel the prediction
248+
* @throws {Error} If the prediction failed
249+
* @yields {ServerSentEvent} Each streamed event from the prediction
250+
*/
251+
async *stream(ref, options) {
252+
const { wait, ...data } = options;
253+
254+
const identifier = ModelVersionIdentifier.parse(ref);
255+
256+
let prediction;
257+
if (identifier.version) {
258+
prediction = await this.predictions.create({
259+
...data,
260+
version: identifier.version,
261+
stream: true,
262+
});
263+
} else {
264+
prediction = await this.models.predictions.create(
265+
identifier.owner,
266+
identifier.name,
267+
{ ...data, stream: true }
268+
);
269+
}
270+
271+
if (prediction.urls && prediction.urls.stream) {
272+
const { signal } = options;
273+
const stream = new Stream(prediction.urls.stream, { signal });
274+
yield* stream;
275+
} else {
276+
throw new Error("Prediction does not support streaming");
277+
}
278+
}
279+
238280
/**
239281
* Paginate through a list of results.
240282
*

lib/stream.js

Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
const { Readable } = require("stream");
2+
3+
/**
 * One event received over a server-sent events connection.
 */
class ServerSentEvent {
  /**
   * Build an event from its already-parsed fields.
   *
   * @param {string} event Name (type) of the event.
   * @param {string} data Payload carried by the event.
   * @param {string} id Identifier of the event.
   * @param {number} retry Reconnection delay sent with the event.
   */
  constructor(event, data, id, retry) {
    Object.assign(this, { event, data, id, retry });
  }

  /**
   * String form of the event: the payload for `output` events,
   * and an empty string for every other event type.
   */
  toString() {
    return this.event === "output" ? this.data : "";
  }
}
33+
34+
/**
 * A stream of server-sent events.
 *
 * Iterating the stream with `for await` opens the connection with `fetch`
 * and yields one `ServerSentEvent` per event received. Parser state is kept
 * on the instance so that events split across network chunks are
 * reassembled correctly.
 */
class Stream extends Readable {
  /**
   * Create a new stream of server-sent events.
   *
   * @param {string} url The URL to connect to.
   * @param {object} options The fetch options.
   */
  constructor(url, options) {
    super();
    this.url = url;
    this.options = options;

    // Fields of the event currently being assembled.
    this.event = null;
    this.data = [];
    this.lastEventId = null;
    this.retry = null;

    // Text received so far that does not yet end in a newline.
    this.pending = "";
  }

  /**
   * Feed one line (without its line terminator) to the parser.
   *
   * @param {string} line A single line of the event stream.
   * @returns {ServerSentEvent|null} The completed event when `line` is the
   *   blank line that terminates one, otherwise `null`.
   */
  decode(line) {
    if (!line) {
      // A blank line dispatches the buffered event. `lastEventId` is
      // deliberately not part of this check: it persists across events, so
      // testing it would emit an empty event for every extra blank line
      // once any id had been seen.
      if (!this.event && !this.data.length) {
        return null;
      }

      const sse = new ServerSentEvent(
        this.event,
        this.data.join("\n"),
        this.lastEventId,
        this.retry === null ? undefined : this.retry
      );

      this.event = null;
      this.data = [];
      this.retry = null;

      return sse;
    }

    // Lines starting with a colon are comments (often used as keep-alives).
    if (line.startsWith(":")) {
      return null;
    }

    // Split on the FIRST colon only, so values that themselves contain
    // ": " are kept intact; a single leading space in the value is dropped.
    const colon = line.indexOf(":");
    const field = colon === -1 ? line : line.slice(0, colon);
    let value = colon === -1 ? "" : line.slice(colon + 1);
    if (value.startsWith(" ")) {
      value = value.slice(1);
    }

    if (field === "event") {
      this.event = value;
    } else if (field === "data") {
      this.data.push(value);
    } else if (field === "id") {
      this.lastEventId = value;
    } else if (field === "retry") {
      // Only a plain run of digits is a valid reconnection time.
      if (/^\d+$/.test(value)) {
        this.retry = Number(value);
      }
    }

    return null;
  }

  /**
   * Append decoded text to the line buffer and parse every complete line.
   * A trailing partial line is kept until the next call.
   *
   * @param {string} text Decoded text from the next network chunk.
   * @returns {ServerSentEvent[]} Events completed by this chunk, in order.
   */
  _feed(text) {
    this.pending += text;

    const lines = this.pending.split("\n");
    // The last element is whatever follows the final newline: an
    // incomplete line (or "") that must wait for more data.
    this.pending = lines.pop();

    const events = [];
    for (const raw of lines) {
      // Accept CRLF as well as LF line endings.
      const line = raw.endsWith("\r") ? raw.slice(0, -1) : raw;
      const sse = this.decode(line);
      if (sse) {
        events.push(sse);
      }
    }

    return events;
  }

  /**
   * Connect to the URL and yield events until a `done` event arrives or
   * the response body ends.
   *
   * @yields {ServerSentEvent} Each event received from the server.
   * @throws {Error} If the server responds with a non-2xx status, or if an
   *   `error` event is received (its data becomes the error message).
   */
  async *[Symbol.asyncIterator]() {
    const options = this.options || {};

    // Keep caller-supplied headers, but always request an event stream.
    const headers = new Headers(options.headers);
    headers.set("Accept", "text/event-stream");

    const response = await fetch(this.url, { ...options, headers });

    if (!response.ok) {
      throw new Error(
        `Failed to open event stream: ${response.status} ${response.statusText}`
      );
    }

    // One decoder for the whole response, in streaming mode, so that a
    // multi-byte character split across two chunks is decoded correctly.
    const decoder = new TextDecoder("utf-8");
    this.pending = "";

    for await (const chunk of response.body) {
      const events = this._feed(decoder.decode(chunk, { stream: true }));
      for (const sse of events) {
        if (sse.event === "error") {
          throw new Error(sse.data);
        }

        yield sse;

        if (sse.event === "done") {
          return;
        }
      }
    }
  }
}
119+
120+
module.exports = {
121+
Stream,
122+
ServerSentEvent,
123+
};

0 commit comments

Comments
 (0)