Skip to content

Commit 8551a97

Browse files
authored
Merge branch 'main' into mr/add-get-deployment
2 parents 59d1933 + c0d2a01 commit 8551a97

File tree

5 files changed

+185
-3
lines changed

5 files changed

+185
-3
lines changed

index.d.ts

+25
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ declare module "replicate" {
105105
};
106106
};
107107
}
108+
109+
export interface WebhookSecret {
110+
key: string;
111+
}
108112

109113
export default class Replicate {
110114
constructor(options?: {
@@ -254,5 +258,26 @@ declare module "replicate" {
254258
cancel(training_id: string): Promise<Training>;
255259
list(): Promise<Page<Training>>;
256260
};
261+
262+
webhooks: {
263+
default: {
264+
secret: {
265+
get(): Promise<WebhookSecret>;
266+
};
267+
};
268+
};
257269
}
270+
271+
export function validateWebhook(
272+
requestData:
273+
| Request
274+
| {
275+
id?: string;
276+
timestamp?: string;
277+
body: string;
278+
secret?: string;
279+
signature?: string;
280+
},
281+
secret: string
282+
): boolean;
258283
}

index.js

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
33
const { Stream } = require("./lib/stream");
4-
const { withAutomaticRetries } = require("./lib/util");
4+
const { withAutomaticRetries, validateWebhook } = require("./lib/util");
55

66
const accounts = require("./lib/accounts");
77
const collections = require("./lib/collections");
@@ -10,6 +10,7 @@ const hardware = require("./lib/hardware");
1010
const models = require("./lib/models");
1111
const predictions = require("./lib/predictions");
1212
const trainings = require("./lib/trainings");
13+
const webhooks = require("./lib/webhooks");
1314

1415
const packageJSON = require("./package.json");
1516

@@ -91,6 +92,14 @@ class Replicate {
9192
cancel: trainings.cancel.bind(this),
9293
list: trainings.list.bind(this),
9394
};
95+
96+
this.webhooks = {
97+
default: {
98+
secret: {
99+
get: webhooks.default.secret.get.bind(this),
100+
},
101+
},
102+
};
94103
}
95104

96105
/**
@@ -365,3 +374,4 @@ class Replicate {
365374
}
366375

367376
module.exports = Replicate;
377+
module.exports.validateWebhook = validateWebhook;

index.test.ts

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import { expect, jest, test } from "@jest/globals";
2-
import Replicate, { ApiError, Model, Prediction } from "replicate";
2+
import Replicate, {
3+
ApiError,
4+
Model,
5+
Prediction,
6+
validateWebhook,
7+
} from "replicate";
38
import nock from "nock";
49
import fetch from "cross-fetch";
510

@@ -1037,5 +1042,39 @@ describe("Replicate client", () => {
10371042
});
10381043
});
10391044

1045+
describe("webhooks.default.secret.get", () => {
1046+
test("Calls the correct API route", async () => {
1047+
nock(BASE_URL).get("/webhooks/default/secret").reply(200, {
1048+
key: "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH",
1049+
});
1050+
1051+
const secret = await client.webhooks.default.secret.get();
1052+
expect(secret.key).toBe("whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH");
1053+
});
1054+
1055+
test("Can be used to validate webhook", async () => {
1056+
// Test case from https://github.com/svix/svix-webhooks/blob/b41728cd98a7e7004a6407a623f43977b82fcba4/javascript/src/webhook.test.ts#L190-L200
1057+
const request = new Request("http://test.host/webhook", {
1058+
method: "POST",
1059+
headers: {
1060+
"Content-Type": "application/json",
1061+
"Webhook-ID": "msg_p5jXN8AQM9LWM0D4loKWxJek",
1062+
"Webhook-Timestamp": "1614265330",
1063+
"Webhook-Signature":
1064+
"v1,g0hM9SsE+OTPJTGt/tmIKtSyZlE3uFJELVlNIOLJ1OE=",
1065+
},
1066+
body: `{"test": 2432232314}`,
1067+
});
1068+
1069+
// This is a test secret and should not be used in production
1070+
const secret = "whsec_MfKQ9r8GKYqrTwjUPD8ILPZIo2LaLaSw";
1071+
1072+
const isValid = await validateWebhook(request, secret);
1073+
expect(isValid).toBe(true);
1074+
});
1075+
1076+
// Add more tests for error handling, edge cases, etc.
1077+
});
1078+
10401079
// Continue with tests for other methods
10411080
});

lib/util.js

+89-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,93 @@
1+
const crypto = require("node:crypto");
2+
13
const ApiError = require("./error");
24

5+
/**
6+
* @see {@link validateWebhook}
7+
* @overload
8+
* @param {object} requestData - The request data
9+
* @param {string} requestData.id - The webhook ID header from the incoming request.
10+
* @param {string} requestData.timestamp - The webhook timestamp header from the incoming request.
11+
* @param {string} requestData.body - The raw body of the incoming webhook request.
12+
* @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method.
13+
* @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures.
14+
*/
15+
16+
/**
17+
* @see {@link validateWebhook}
18+
* @overload
19+
* @param {object} requestData - The request object
20+
* @param {object} requestData.headers - The request headers
21+
* @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request
22+
* @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request
23+
* @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures
24+
* @param {string} requestData.body - The raw body of the incoming webhook request
25+
* @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method
26+
*/
27+
28+
/**
29+
* Validate a webhook signature
30+
*
31+
* @returns {boolean} - True if the signature is valid
32+
* @throws {Error} - If the request is missing required headers, body, or secret
33+
*/
34+
async function validateWebhook(requestData, secret) {
35+
let { id, timestamp, body, signature } = requestData;
36+
const signingSecret = secret || requestData.secret;
37+
38+
if (requestData && requestData.headers && requestData.body) {
39+
id = requestData.headers.get("webhook-id");
40+
timestamp = requestData.headers.get("webhook-timestamp");
41+
signature = requestData.headers.get("webhook-signature");
42+
body = requestData.body;
43+
}
44+
45+
if (body instanceof ReadableStream || body.readable) {
46+
try {
47+
const chunks = [];
48+
for await (const chunk of body) {
49+
chunks.push(Buffer.from(chunk));
50+
}
51+
body = Buffer.concat(chunks).toString("utf8");
52+
} catch (err) {
53+
throw new Error(`Error reading body: ${err.message}`);
54+
}
55+
} else if (body instanceof Buffer) {
56+
body = body.toString("utf8");
57+
} else if (typeof body !== "string") {
58+
throw new Error("Invalid body type");
59+
}
60+
61+
if (!id || !timestamp || !signature) {
62+
throw new Error("Missing required webhook headers");
63+
}
64+
65+
if (!body) {
66+
throw new Error("Missing required body");
67+
}
68+
69+
if (!signingSecret) {
70+
throw new Error("Missing required secret");
71+
}
72+
73+
const signedContent = `${id}.${timestamp}.${body}`;
74+
75+
const secretBytes = Buffer.from(signingSecret.split("_")[1], "base64");
76+
77+
const computedSignature = crypto
78+
.createHmac("sha256", secretBytes)
79+
.update(signedContent)
80+
.digest("base64");
81+
82+
const expectedSignatures = signature
83+
.split(" ")
84+
.map((sig) => sig.split(",")[1]);
85+
86+
return expectedSignatures.some(
87+
(expectedSignature) => expectedSignature === computedSignature
88+
);
89+
}
90+
391
/**
492
* Automatically retry a request if it fails with an appropriate status code.
593
*
@@ -68,4 +156,4 @@ async function withAutomaticRetries(request, options = {}) {
68156
return request();
69157
}
70158

71-
module.exports = { withAutomaticRetries };
159+
module.exports = { validateWebhook, withAutomaticRetries };

lib/webhooks.js

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/**
2+
* Get the default webhook signing secret
3+
*
4+
* @returns {Promise<object>} Resolves with the signing secret for the default webhook
5+
*/
6+
async function getDefaultWebhookSecret() {
7+
const response = await this.request("/webhooks/default/secret", {
8+
method: "GET",
9+
});
10+
11+
return response.json();
12+
}
13+
14+
module.exports = {
15+
default: {
16+
secret: {
17+
get: getDefaultWebhookSecret,
18+
},
19+
},
20+
};

0 commit comments

Comments
 (0)