Skip to content

add custom headers on initial _startOrAuth call #318

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

23 changes: 23 additions & 0 deletions src/client/sse.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,29 @@ describe("SSEClientTransport", () => {
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});

it("attaches custom header from provider on initial SSE connection", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
token_type: "Bearer"
});
const customHeaders = {
"X-Custom-Header": "custom-value",
};

transport = new SSEClientTransport(resourceBaseUrl, {
authProvider: mockAuthProvider,
requestInit: {
headers: customHeaders,
},
});

await transport.start();

expect(lastServerRequest.headers.authorization).toBe("Bearer test-token");
expect(lastServerRequest.headers["x-custom-header"]).toBe("custom-value");
expect(mockAuthProvider.tokens).toHaveBeenCalled();
});

it("attaches auth header from provider on POST requests", async () => {
mockAuthProvider.tokens.mockResolvedValue({
access_token: "test-token",
Expand Down
23 changes: 10 additions & 13 deletions src/client/sse.ts
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,8 @@ export class SSEClientTransport implements Transport {
return await this._startOrAuth();
}

private async _commonHeaders(): Promise<HeadersInit> {
const headers = {
...this._requestInit?.headers,
} as HeadersInit & Record<string, string>;
private async _commonHeaders(): Promise<Headers> {
const headers: HeadersInit = {};
if (this._authProvider) {
const tokens = await this._authProvider.tokens();
if (tokens) {
Expand All @@ -120,24 +118,24 @@ export class SSEClientTransport implements Transport {
headers["mcp-protocol-version"] = this._protocolVersion;
}

return headers;
return new Headers(
{ ...headers, ...this._requestInit?.headers }
);
}

private _startOrAuth(): Promise<void> {
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typeof fetch
return new Promise((resolve, reject) => {
this._eventSource = new EventSource(
this._url.href,
{
...this._eventSourceInit,
fetch: async (url, init) => {
const headers = await this._commonHeaders()
const headers = await this._commonHeaders();
headers.set("Accept", "text/event-stream");
const response = await fetchImpl(url, {
...init,
headers: new Headers({
...headers,
Accept: "text/event-stream"
})
headers,
})

if (response.status === 401 && response.headers.has('www-authenticate')) {
Expand Down Expand Up @@ -238,8 +236,7 @@ const fetchImpl = (this?._eventSourceInit?.fetch ?? this._fetch ?? fetch) as typ
}

try {
const commonHeaders = await this._commonHeaders();
const headers = new Headers(commonHeaders);
const headers = await this._commonHeaders();
headers.set("content-type", "application/json");
const init = {
...this._requestInit,
Expand Down