Skip to content

Commit 88a1cac

Browse files
committed
Add support for validating webhooks
1 parent cafb2a5 commit 88a1cac

File tree

5 files changed

+167
-3
lines changed

5 files changed

+167
-3
lines changed

index.d.ts

+25
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,10 @@ declare module "replicate" {
8282
retry?: number;
8383
}
8484

85+
export interface WebhookSecret {
86+
key: string;
87+
}
88+
8589
export default class Replicate {
8690
constructor(options?: {
8791
auth?: string;
@@ -222,5 +226,26 @@ declare module "replicate" {
222226
cancel(training_id: string): Promise<Training>;
223227
list(): Promise<Page<Training>>;
224228
};
229+
230+
webhooks: {
231+
default: {
232+
secret: {
233+
get(): Promise<WebhookSecret>;
234+
};
235+
};
236+
};
225237
}
238+
239+
export function validateWebhook(
240+
requestData:
241+
| Request
242+
| {
243+
id?: string;
244+
timestamp?: string;
245+
body: string;
246+
secret?: string;
247+
signature?: string;
248+
},
249+
secret: string
250+
): boolean;
226251
}

index.js

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
const ApiError = require("./lib/error");
22
const ModelVersionIdentifier = require("./lib/identifier");
33
const { Stream } = require("./lib/stream");
4-
const { withAutomaticRetries } = require("./lib/util");
4+
const { withAutomaticRetries, validateWebhook } = require("./lib/util");
55

66
const collections = require("./lib/collections");
77
const deployments = require("./lib/deployments");
88
const hardware = require("./lib/hardware");
99
const models = require("./lib/models");
1010
const predictions = require("./lib/predictions");
1111
const trainings = require("./lib/trainings");
12+
const webhooks = require("./lib/webhooks");
1213

1314
const packageJSON = require("./package.json");
1415

@@ -85,6 +86,14 @@ class Replicate {
8586
cancel: trainings.cancel.bind(this),
8687
list: trainings.list.bind(this),
8788
};
89+
90+
this.webhooks = {
91+
default: {
92+
secret: {
93+
get: webhooks.default.secret.get.bind(this),
94+
},
95+
},
96+
};
8897
}
8998

9099
/**
@@ -359,3 +368,4 @@ class Replicate {
359368
}
360369

361370
module.exports = Replicate;
371+
module.exports.validateWebhook = validateWebhook;

index.test.ts

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
import { expect, jest, test } from "@jest/globals";
2-
import Replicate, { ApiError, Model, Prediction } from "replicate";
2+
import Replicate, {
3+
ApiError,
4+
Model,
5+
Prediction,
6+
validateWebhook,
7+
} from "replicate";
38
import nock from "nock";
49
import fetch from "cross-fetch";
510

@@ -980,5 +985,37 @@ describe("Replicate client", () => {
980985
});
981986
});
982987

988+
describe("webhooks.default.secret.get", () => {
989+
test("Calls the correct API route", async () => {
990+
nock(BASE_URL).get("/webhooks/default/secret").reply(200, {
991+
key: "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH",
992+
});
993+
994+
const secret = await client.webhooks.default.secret.get();
995+
expect(secret.key).toBe("whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH");
996+
});
997+
998+
test("Can be used to validate webhook", () => {
999+
const secret = "whsec_5WbX5kEWLlfzsGNjH64I8lOOqUB6e8FH";
1000+
1001+
const request = new Request("http://test.host/webhook", {
1002+
method: "POST",
1003+
headers: {
1004+
"Content-Type": "application/json",
1005+
"Webhook-ID": "123",
1006+
"Webhook-Timestamp": "1707329251",
1007+
"Webhook-Signature":
1008+
"v1,3Bh4jH4/KdieFo0oCZw+piC59XQiH4s1ySUQz+FgJqI=",
1009+
},
1010+
body: JSON.stringify({ event: "output", data: "Hello, world!" }),
1011+
});
1012+
1013+
const isValid = validateWebhook(request, secret);
1014+
expect(isValid).toBe(true);
1015+
});
1016+
1017+
// Add more tests for error handling, edge cases, etc.
1018+
});
1019+
9831020
// Continue with tests for other methods
9841021
});

lib/util.js

+73-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,77 @@
1+
const crypto = require("node:crypto");
2+
13
const ApiError = require("./error");
24

5+
/**
6+
* @see {@link validateWebhook}
7+
* @overload
8+
* @param {object} requestData - The request data
9+
* @param {string} requestData.id - The webhook ID header from the incoming request.
10+
* @param {string} requestData.timestamp - The webhook timestamp header from the incoming request.
11+
* @param {string} requestData.body - The raw body of the incoming webhook request.
12+
* @param {string} requestData.secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method.
13+
* @param {string} requestData.signature - The webhook signature header from the incoming request, comprising one or more space-delimited signatures.
14+
*/
15+
16+
/**
17+
* @see {@link validateWebhook}
18+
* @overload
19+
* @param {object} requestData - The request object
20+
* @param {object} requestData.headers - The request headers
21+
* @param {string} requestData.headers["webhook-id"] - The webhook ID header from the incoming request
22+
* @param {string} requestData.headers["webhook-timestamp"] - The webhook timestamp header from the incoming request
23+
* @param {string} requestData.headers["webhook-signature"] - The webhook signature header from the incoming request, comprising one or more space-delimited signatures
24+
* @param {string} requestData.body - The raw body of the incoming webhook request
25+
* @param {string} secret - The webhook secret, obtained from `replicate.webhooks.defaul.secret` method
26+
*/
27+
28+
/**
29+
* Validate a webhook signature
30+
*
31+
* @returns {boolean} - True if the signature is valid
32+
* @throws {Error} - If the request is missing required headers, body, or secret
33+
*/
34+
function validateWebhook(requestData, secret) {
35+
let { id, timestamp, body, signature } = requestData;
36+
const signingSecret = secret || requestData.secret;
37+
38+
if (requestData && requestData.headers && requestData.body) {
39+
id = requestData.headers.get("webhook-id");
40+
timestamp = requestData.headers.get("webhook-timestamp");
41+
signature = requestData.headers.get("webhook-signature");
42+
body = requestData.body;
43+
}
44+
45+
if (!id || !timestamp || !signature) {
46+
throw new Error("Missing required webhook headers");
47+
}
48+
49+
if (!body) {
50+
throw new Error("Missing required body");
51+
}
52+
53+
if (!signingSecret) {
54+
throw new Error("Missing required secret");
55+
}
56+
57+
const signedContent = `${id}.${timestamp}.${body}`;
58+
59+
const secretBytes = Buffer.from(signingSecret.split("_")[1], "base64");
60+
61+
const computedSignature = crypto
62+
.createHmac("sha256", secretBytes)
63+
.update(signedContent)
64+
.digest("base64");
65+
66+
const expectedSignatures = signature
67+
.split(" ")
68+
.map((sig) => sig.split(",")[1]);
69+
70+
return expectedSignatures.some(
71+
(expectedSignature) => expectedSignature === computedSignature
72+
);
73+
}
74+
375
/**
476
* Automatically retry a request if it fails with an appropriate status code.
577
*
@@ -68,4 +140,4 @@ async function withAutomaticRetries(request, options = {}) {
68140
return request();
69141
}
70142

71-
module.exports = { withAutomaticRetries };
143+
module.exports = { validateWebhook, withAutomaticRetries };

lib/webhooks.js

+20
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
/**
2+
* Get the default webhook signing secret
3+
*
4+
* @returns {Promise<object>} Resolves with the signing secret for the default webhook
5+
*/
6+
async function getDefaultWebhookSecret() {
7+
const response = await this.request("/webhooks/default/secret", {
8+
method: "GET",
9+
});
10+
11+
return response.json();
12+
}
13+
14+
module.exports = {
15+
default: {
16+
secret: {
17+
get: getDefaultWebhookSecret,
18+
},
19+
},
20+
};

0 commit comments

Comments
 (0)