From 3571fc353f2e4c80c4667d3f8252688c5e2ce960 Mon Sep 17 00:00:00 2001
From: wangshijun <wangshijun2010@gmail.com>
Date: Tue, 1 Apr 2025 12:21:08 +0800
Subject: [PATCH 1/4] feat: support extending McpServer with subclass

---
 package-lock.json       |  4 ++--
 src/server/mcp.ts       | 29 +++++++++++++----------------
 src/shared/protocol.ts  |  8 +++++++-
 src/shared/transport.ts |  5 +++++
 4 files changed, 27 insertions(+), 19 deletions(-)

diff --git a/package-lock.json b/package-lock.json
index 73f1cbba..8338e3c4 100644
--- a/package-lock.json
+++ b/package-lock.json
@@ -1,12 +1,12 @@
 {
   "name": "@modelcontextprotocol/sdk",
-  "version": "1.7.0",
+  "version": "1.8.0",
   "lockfileVersion": 3,
   "requires": true,
   "packages": {
     "": {
       "name": "@modelcontextprotocol/sdk",
-      "version": "1.7.0",
+      "version": "1.8.0",
       "license": "MIT",
       "dependencies": {
         "content-type": "^1.0.5",
diff --git a/src/server/mcp.ts b/src/server/mcp.ts
index 8f4a909c..ed93256f 100644
--- a/src/server/mcp.ts
+++ b/src/server/mcp.ts
@@ -54,12 +54,17 @@ export class McpServer {
    */
   public readonly server: Server;
 
-  private _registeredResources: { [uri: string]: RegisteredResource } = {};
-  private _registeredResourceTemplates: {
+  protected _registeredResources: { [uri: string]: RegisteredResource } = {};
+  protected _registeredResourceTemplates: {
     [name: string]: RegisteredResourceTemplate;
   } = {};
-  private _registeredTools: { [name: string]: RegisteredTool } = {};
-  private _registeredPrompts: { [name: string]: RegisteredPrompt } = {};
+  protected _registeredTools: { [name: string]: RegisteredTool } = {};
+  protected _registeredPrompts: { [name: string]: RegisteredPrompt } = {};
+
+  protected _toolHandlersInitialized = false;
+  protected _completionHandlerInitialized = false;
+  protected _resourceHandlersInitialized = false;
+  protected _promptHandlersInitialized = false;
 
   constructor(serverInfo: Implementation, options?: ServerOptions) {
     this.server = new Server(serverInfo, options);
@@ -81,13 +86,11 @@ export class McpServer {
     await this.server.close();
   }
 
-  private _toolHandlersInitialized = false;
-
   private setToolRequestHandlers() {
     if (this._toolHandlersInitialized) {
       return;
     }
-    
+
     this.server.assertCanSetRequestHandler(
       ListToolsRequestSchema.shape.method.value,
     );
@@ -177,8 +180,6 @@ export class McpServer {
     this._toolHandlersInitialized = true;
   }
 
-  private _completionHandlerInitialized = false;
-
   private setCompletionRequestHandler() {
     if (this._completionHandlerInitialized) {
       return;
@@ -267,8 +268,6 @@ export class McpServer {
     return createCompletionResult(suggestions);
   }
 
-  private _resourceHandlersInitialized = false;
-
   private setResourceRequestHandlers() {
     if (this._resourceHandlersInitialized) {
       return;
@@ -366,12 +365,10 @@ export class McpServer {
     );
 
     this.setCompletionRequestHandler();
-    
+
     this._resourceHandlersInitialized = true;
   }
 
-  private _promptHandlersInitialized = false;
-
   private setPromptRequestHandlers() {
     if (this._promptHandlersInitialized) {
       return;
@@ -438,7 +435,7 @@ export class McpServer {
     );
 
     this.setCompletionRequestHandler();
-    
+
     this._promptHandlersInitialized = true;
   }
 
@@ -770,7 +767,7 @@ type RegisteredPrompt = {
   callback: PromptCallback<undefined | PromptArgsRawShape>;
 };
 
-function promptArgumentsFromSchema(
+export function promptArgumentsFromSchema(
   schema: ZodObject<PromptArgsRawShape>,
 ): PromptArgument[] {
   return Object.entries(schema.shape).map(
diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts
index a5b6ad51..641e30c2 100644
--- a/src/shared/protocol.ts
+++ b/src/shared/protocol.ts
@@ -93,6 +93,11 @@ export type RequestHandlerExtra = {
    * The session ID from the transport, if available.
    */
   sessionId?: string;
+
+  /**
+   * The authenticated user, if available.
+   */
+  user?: unknown;
 };
 
 /**
@@ -316,6 +321,7 @@ export abstract class Protocol<
     const extra: RequestHandlerExtra = {
       signal: abortController.signal,
       sessionId: this._transport?.sessionId,
+      user: this._transport?.user,
     };
 
     // Starting with Promise.resolve() puts any synchronous errors into the monad as well.
@@ -361,7 +367,7 @@ export abstract class Protocol<
   private _onprogress(notification: ProgressNotification): void {
     const { progressToken, ...params } = notification.params;
     const messageId = Number(progressToken);
-    
+
     const handler = this._progressHandlers.get(messageId);
     if (!handler) {
       this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`));
diff --git a/src/shared/transport.ts b/src/shared/transport.ts
index b80e2a51..88cea7f2 100644
--- a/src/shared/transport.ts
+++ b/src/shared/transport.ts
@@ -46,4 +46,9 @@ export interface Transport {
    * The session ID generated for this connection.
    */
   sessionId?: string;
+
+  /**
+   * The authenticated user for this transport session.
+   */
+  user?: unknown;
 }

From a3590b9bee01305923f31df4465085adebce3518 Mon Sep 17 00:00:00 2001
From: wangshijun <wangshijun2010@gmail.com>
Date: Sat, 5 Apr 2025 07:46:43 +0800
Subject: [PATCH 2/4] chore: add coverage npm script

---
 jest.config.js | 1 +
 package.json   | 1 +
 2 files changed, 2 insertions(+)

diff --git a/jest.config.js b/jest.config.js
index f8f621c8..a0021104 100644
--- a/jest.config.js
+++ b/jest.config.js
@@ -12,5 +12,6 @@ export default {
   transformIgnorePatterns: [
     "/node_modules/(?!eventsource)/"
   ],
+  collectCoverageFrom: ["src/**/*.ts"],
   testPathIgnorePatterns: ["/node_modules/", "/dist/"],
 };
diff --git a/package.json b/package.json
index e2d8b3d7..d14fd9e6 100644
--- a/package.json
+++ b/package.json
@@ -41,6 +41,7 @@
     "prepack": "npm run build:esm && npm run build:cjs",
     "lint": "eslint src/",
     "test": "jest",
+    "coverage": "jest --coverage",
     "start": "npm run server",
     "server": "tsx watch --clear-screen=false src/cli.ts server",
     "client": "tsx src/cli.ts client"

From 730cd670b008e50c522091232252d41e772930d5 Mon Sep 17 00:00:00 2001
From: wangshijun <wangshijun2010@gmail.com>
Date: Sat, 5 Apr 2025 07:47:04 +0800
Subject: [PATCH 3/4] chore: add prettierrc to enforce consistent coding style

---
 .prettierrc | 8 ++++++++
 1 file changed, 8 insertions(+)
 create mode 100644 .prettierrc

diff --git a/.prettierrc b/.prettierrc
new file mode 100644
index 00000000..e0180a40
--- /dev/null
+++ b/.prettierrc
@@ -0,0 +1,8 @@
+{
+  "printWidth": 120,
+  "tabWidth": 2,
+  "trailingComma": "all",
+  "jsxBracketSameLine": true,
+  "semi": true,
+  "singleQuote": false
+}

From 473dc6cad568d978fff0f0d7c26e4ff963192cff Mon Sep 17 00:00:00 2001
From: wangshijun <wangshijun2010@gmail.com>
Date: Sat, 5 Apr 2025 11:31:59 +0800
Subject: [PATCH 4/4] chore: add test case for McpServer with auth

---
 .prettierrc            |   2 +-
 src/inMemory.ts        |   1 +
 src/server/mcp.test.ts | 404 ++++++++++++++++++++++++++++++++++++++---
 src/server/sse.ts      |   1 +
 src/server/stdio.ts    |   3 +-
 5 files changed, 385 insertions(+), 26 deletions(-)

diff --git a/.prettierrc b/.prettierrc
index e0180a40..4379c748 100644
--- a/.prettierrc
+++ b/.prettierrc
@@ -1,5 +1,5 @@
 {
-  "printWidth": 120,
+  "printWidth": 80,
   "tabWidth": 2,
   "trailingComma": "all",
   "jsxBracketSameLine": true,
diff --git a/src/inMemory.ts b/src/inMemory.ts
index 106a9e7e..65915baa 100644
--- a/src/inMemory.ts
+++ b/src/inMemory.ts
@@ -12,6 +12,7 @@ export class InMemoryTransport implements Transport {
   onerror?: (error: Error) => void;
   onmessage?: (message: JSONRPCMessage) => void;
   sessionId?: string;
+  user?: unknown;
 
   /**
    * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server.
diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts
index 2e91a568..08518b20 100644
--- a/src/server/mcp.test.ts
+++ b/src/server/mcp.test.ts
@@ -1,7 +1,8 @@
-import { McpServer } from "./mcp.js";
+import { McpServer, ToolCallback } from "./mcp.js";
 import { Client } from "../client/index.js";
 import { InMemoryTransport } from "../inMemory.js";
-import { z } from "zod";
+import { z, ZodRawShape } from "zod";
+import { zodToJsonSchema } from "zod-to-json-schema";
 import {
   ListToolsResultSchema,
   CallToolResultSchema,
@@ -11,10 +12,16 @@ import {
   ListPromptsResultSchema,
   GetPromptResultSchema,
   CompleteResultSchema,
+  CallToolRequestSchema,
+  CallToolRequest,
+  ListToolsRequestSchema,
+  ListToolsResult,
+  Tool,
 } from "../types.js";
 import { ResourceTemplate } from "./mcp.js";
 import { completable } from "./completable.js";
 import { UriTemplate } from "../shared/uriTemplate.js";
+import { RequestHandlerExtra } from "../shared/protocol.js";
 
 describe("McpServer", () => {
   test("should expose underlying Server instance", () => {
@@ -318,7 +325,7 @@ describe("tool()", () => {
 
     // This should succeed
     mcpServer.tool("tool1", () => ({ content: [] }));
-    
+
     // This should also succeed and not throw about request handlers
     mcpServer.tool("tool2", () => ({ content: [] }));
   });
@@ -354,7 +361,8 @@ describe("tool()", () => {
       };
     });
 
-    const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair();
+    const [clientTransport, serverTransport] =
+      InMemoryTransport.createLinkedPair();
     // Set a test sessionId on the server transport
     serverTransport.sessionId = "test-session-123";
 
@@ -815,7 +823,7 @@ describe("resource()", () => {
         },
       ],
     }));
-    
+
     // This should also succeed and not throw about request handlers
     mcpServer.resource("resource2", "test://resource2", async () => ({
       contents: [
@@ -1321,7 +1329,7 @@ describe("prompt()", () => {
         },
       ],
     }));
-    
+
     // This should also succeed and not throw about request handlers
     mcpServer.prompt("prompt2", async () => ({
       messages: [
@@ -1343,19 +1351,17 @@ describe("prompt()", () => {
     });
 
     // This should succeed
-    mcpServer.prompt(
-      "echo",
-      { message: z.string() },
-      ({ message }) => ({
-        messages: [{
+    mcpServer.prompt("echo", { message: z.string() }, ({ message }) => ({
+      messages: [
+        {
           role: "user",
           content: {
             type: "text",
-            text: `Please process this message: ${message}`
-          }
-        }]
-      })
-    );
+            text: `Please process this message: ${message}`,
+          },
+        },
+      ],
+    }));
   });
 
   test("should allow registering both resources and prompts with completion handlers", () => {
@@ -1388,14 +1394,16 @@ describe("prompt()", () => {
       "echo",
       { message: completable(z.string(), () => ["hello", "world"]) },
       ({ message }) => ({
-        messages: [{
-          role: "user",
-          content: {
-            type: "text",
-            text: `Please process this message: ${message}`
-          }
-        }]
-      })
+        messages: [
+          {
+            role: "user",
+            content: {
+              type: "text",
+              text: `Please process this message: ${message}`,
+            },
+          },
+        ],
+      }),
     );
   });
 
@@ -1582,3 +1590,351 @@ describe("prompt()", () => {
     expect(result.completion.total).toBe(1);
   });
 });
+
+describe("McpServer with Auth Extension", () => {
+  type SessionUser = {
+    role: string;
+    [key: string]: unknown;
+  };
+
+  type AccessPolicy = {
+    allow?: {
+      roles?: string[];
+    };
+    deny?: {
+      roles?: string[];
+    };
+  };
+
+  type RegisteredToolWithAuth = {
+    description?: string;
+    inputSchema?: z.ZodObject<ZodRawShape>;
+    callback: ToolCallback<ZodRawShape | undefined>;
+    accessPolicy?: AccessPolicy;
+  };
+
+  // Just a simple extension to McpServer that adds support for access policies on tools
+  class McpServerWithAuth extends McpServer {
+    protected override _registeredTools: {
+      [name: string]: RegisteredToolWithAuth;
+    } = {};
+    checkPermissions(user?: SessionUser, policy?: AccessPolicy): boolean {
+      if (!policy) {
+        return true;
+      }
+
+      if (!user) {
+        return false;
+      }
+
+      // Check deny rules first
+      if (policy.deny) {
+        // Check denied roles
+        if (policy.deny.roles?.includes(user.role)) {
+          return false;
+        }
+      }
+
+      // Check allow rules
+      if (policy.allow) {
+        let isAllowed = false;
+
+        // If no allow rules are specified, default to allowed
+        if (!policy.allow.roles) {
+          isAllowed = true;
+        } else {
+          // Check allowed roles
+          if (policy.allow.roles?.includes(user.role)) {
+            isAllowed = true;
+          }
+        }
+
+        return isAllowed;
+      }
+
+      // If no rules specified, default to allowed
+      return true;
+    }
+
+    override tool(
+      name: string,
+      cb: ToolCallback,
+      accessPolicy?: AccessPolicy,
+    ): void;
+    override tool(
+      name: string,
+      description: string,
+      cb: ToolCallback,
+      accessPolicy?: AccessPolicy,
+    ): void;
+    override tool<Args extends ZodRawShape>(
+      name: string,
+      paramsSchema: Args,
+      cb: ToolCallback<Args>,
+      accessPolicy?: AccessPolicy,
+    ): void;
+    override tool<Args extends ZodRawShape>(
+      name: string,
+      description: string,
+      paramsSchema: Args,
+      cb: ToolCallback<Args>,
+      accessPolicy?: AccessPolicy,
+    ): void;
+    override tool(name: string, ...rest: unknown[]): void {
+      let description: string | undefined;
+      let paramsSchema: ZodRawShape | undefined;
+      let accessPolicy: AccessPolicy | undefined;
+      let cb: ToolCallback<ZodRawShape | undefined>;
+
+      // Parse arguments based on their types
+      if (typeof rest[0] === "function") {
+        // Case: tool(name, cb, accessPolicy?)
+        cb = rest[0] as ToolCallback<ZodRawShape | undefined>;
+        accessPolicy = rest[1] as AccessPolicy | undefined;
+      } else if (typeof rest[0] === "string") {
+        // Cases with description
+        description = rest[0];
+        if (typeof rest[1] === "function") {
+          // Case: tool(name, description, cb, accessPolicy?)
+          cb = rest[1] as ToolCallback<ZodRawShape | undefined>;
+          accessPolicy = rest[2] as AccessPolicy | undefined;
+        } else {
+          // Case: tool(name, description, paramsSchema, cb, accessPolicy?)
+          paramsSchema = rest[1] as ZodRawShape;
+          cb = rest[2] as ToolCallback<ZodRawShape>;
+          accessPolicy = rest[3] as AccessPolicy | undefined;
+        }
+      } else {
+        // Case: tool(name, paramsSchema, cb, accessPolicy?)
+        paramsSchema = rest[0] as ZodRawShape;
+        cb = rest[1] as ToolCallback<ZodRawShape>;
+        accessPolicy = rest[2] as AccessPolicy | undefined;
+      }
+
+      // Register with base class
+      const args: unknown[] = [name];
+      if (description) args.push(description);
+      if (paramsSchema) args.push(paramsSchema);
+      args.push(cb);
+
+      // Set up request handlers if not already initialized
+      if (!this._toolHandlersInitialized) {
+        this.server.assertCanSetRequestHandler(
+          CallToolRequestSchema.shape.method.value,
+        );
+        this.server.assertCanSetRequestHandler(
+          ListToolsRequestSchema.shape.method.value,
+        );
+        this.server.registerCapabilities({ tools: {} });
+
+        // Add ListToolsRequestSchema handler
+        this.server.setRequestHandler(
+          ListToolsRequestSchema,
+          (request, extra): ListToolsResult => {
+            const user = extra.user as SessionUser | undefined;
+
+            // Filter tools based on permissions
+            const accessibleTools = Object.entries(this._registeredTools)
+              .filter(([_, tool]) =>
+                this.checkPermissions(user, tool.accessPolicy),
+              )
+              .map(
+                ([name, tool]): Tool => ({
+                  name,
+                  description: tool.description,
+                  inputSchema: tool.inputSchema
+                    ? (zodToJsonSchema(tool.inputSchema, {
+                        strictUnions: true,
+                      }) as Tool["inputSchema"])
+                    : { type: "object" },
+                }),
+              );
+
+            return { tools: accessibleTools };
+          },
+        );
+
+        this.server.setRequestHandler(
+          CallToolRequestSchema,
+          async (request: CallToolRequest, extra: RequestHandlerExtra) => {
+            const tool = this._registeredTools[request.params.name];
+            if (!tool) {
+              throw new Error(`Tool ${request.params.name} not found`);
+            }
+
+            if (
+              !this.checkPermissions(
+                extra.user as SessionUser,
+                tool.accessPolicy,
+              )
+            ) {
+              throw new Error(`Access denied for tool: ${request.params.name}`);
+            }
+
+            if (tool.inputSchema) {
+              const parseResult = await tool.inputSchema.safeParseAsync(
+                request.params.arguments,
+              );
+              if (!parseResult.success) {
+                throw new Error(
+                  `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`,
+                );
+              }
+
+              const args = parseResult.data;
+              const cb = tool.callback as ToolCallback<ZodRawShape>;
+              return await Promise.resolve(cb(args, extra));
+            } else {
+              const cb = tool.callback as ToolCallback<undefined>;
+              return await Promise.resolve(cb(extra));
+            }
+          },
+        );
+        this._toolHandlersInitialized = true;
+      }
+
+      McpServer.prototype.tool.apply(
+        this,
+        args as Parameters<typeof McpServer.prototype.tool>,
+      );
+      this._registeredTools[name].accessPolicy = accessPolicy;
+    }
+  }
+
+  const mcpServer = new McpServerWithAuth({
+    name: "test server with auth",
+    version: "1.0",
+  });
+  const client = new Client({
+    name: "test client",
+    version: "1.0",
+  });
+
+  mcpServer.tool("public-tool", async () => ({
+    content: [
+      {
+        type: "text",
+        text: "Public tool response",
+      },
+    ],
+  }));
+
+  mcpServer.tool(
+    "protected-tool",
+    async () => ({
+      content: [
+        {
+          type: "text",
+          text: "Protected tool response",
+        },
+      ],
+    }),
+    {
+      allow: {
+        roles: ["admin"],
+      },
+    },
+  );
+
+  test("should public tools work with list and call when unauthenticated", async () => {
+    const [clientTransport, serverTransport] =
+      InMemoryTransport.createLinkedPair();
+    await Promise.all([
+      client.connect(clientTransport),
+      mcpServer.server.connect(serverTransport),
+    ]);
+
+    // Public tool should be accessible
+    const result = await client.request(
+      { method: "tools/list" },
+      ListToolsResultSchema,
+    );
+    expect(result.tools).toHaveLength(1);
+    expect(result.tools[0].name).toBe("public-tool");
+    const response = await client.request(
+      {
+        method: "tools/call",
+        params: {
+          name: "public-tool",
+          arguments: {},
+        },
+      },
+      CallToolResultSchema,
+    );
+    expect(response.content).toEqual([
+      {
+        type: "text",
+        text: "Public tool response",
+      },
+    ]);
+  });
+
+  test("should public tools work with list and call when unauthorized", async () => {
+    const [clientTransport, serverTransport] =
+      InMemoryTransport.createLinkedPair();
+    await Promise.all([
+      client.connect(clientTransport),
+      mcpServer.server.connect(serverTransport),
+    ]);
+
+    // Protected tool should be inaccessible when authenticated as a non-admin user
+    serverTransport.user = { role: "member" };
+    const result = await client.request(
+      { method: "tools/list" },
+      ListToolsResultSchema,
+    );
+    expect(result.tools).toHaveLength(1);
+    expect(result.tools[0].name).toBe("public-tool");
+    const response = await client.request(
+      {
+        method: "tools/call",
+        params: {
+          name: "public-tool",
+          arguments: {},
+        },
+      },
+      CallToolResultSchema,
+    );
+    expect(response.content).toEqual([
+      {
+        type: "text",
+        text: "Public tool response",
+      },
+    ]);
+  });
+
+  test("should protected tools work with list and call when authorized", async () => {
+    const [clientTransport, serverTransport] =
+      InMemoryTransport.createLinkedPair();
+    await Promise.all([
+      client.connect(clientTransport),
+      mcpServer.server.connect(serverTransport),
+    ]);
+
+    serverTransport.user = { role: "admin" };
+    const result = await client.request(
+      { method: "tools/list" },
+      ListToolsResultSchema,
+    );
+    expect(result.tools).toHaveLength(2);
+    expect(result.tools[0].name).toBe("public-tool");
+    expect(result.tools[1].name).toBe("protected-tool");
+
+    const response = await client.request(
+      {
+        method: "tools/call",
+        params: {
+          name: "protected-tool",
+          arguments: {},
+        },
+      },
+      CallToolResultSchema,
+    );
+    expect(response.content).toEqual([
+      {
+        type: "text",
+        text: "Protected tool response",
+      },
+    ]);
+  });
+});
diff --git a/src/server/sse.ts b/src/server/sse.ts
index 84c1cbb9..73603021 100644
--- a/src/server/sse.ts
+++ b/src/server/sse.ts
@@ -19,6 +19,7 @@ export class SSEServerTransport implements Transport {
   onclose?: () => void;
   onerror?: (error: Error) => void;
   onmessage?: (message: JSONRPCMessage) => void;
+  user?: unknown;
 
   /**
    * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`.
diff --git a/src/server/stdio.ts b/src/server/stdio.ts
index 30c80012..7c3b21c6 100644
--- a/src/server/stdio.ts
+++ b/src/server/stdio.ts
@@ -21,6 +21,7 @@ export class StdioServerTransport implements Transport {
   onclose?: () => void;
   onerror?: (error: Error) => void;
   onmessage?: (message: JSONRPCMessage) => void;
+  user?: unknown;
 
   // Arrow functions to bind `this` properly, while maintaining function identity.
   _ondata = (chunk: Buffer) => {
@@ -73,7 +74,7 @@ export class StdioServerTransport implements Transport {
       // This prevents interfering with other parts of the application that might be using stdin
       this._stdin.pause();
     }
-    
+
     // Clear the buffer and notify closure
     this._readBuffer.clear();
     this.onclose?.();