From 8027ae15938fde09e67d342c29cceed3ead1df52 Mon Sep 17 00:00:00 2001 From: Justin Spahr-Summers Date: Mon, 7 Apr 2025 10:47:37 +0100 Subject: [PATCH] Revert "feat: support extending McpServer with authorization" --- .prettierrc | 8 - jest.config.js | 1 - package.json | 1 - src/inMemory.ts | 1 - src/server/mcp.test.ts | 404 +++------------------------------------- src/server/mcp.ts | 29 +-- src/server/sse.ts | 1 - src/server/stdio.ts | 3 +- src/shared/protocol.ts | 8 +- src/shared/transport.ts | 5 - 10 files changed, 42 insertions(+), 419 deletions(-) delete mode 100644 .prettierrc diff --git a/.prettierrc b/.prettierrc deleted file mode 100644 index 4379c748..00000000 --- a/.prettierrc +++ /dev/null @@ -1,8 +0,0 @@ -{ - "printWidth": 80, - "tabWidth": 2, - "trailingComma": "all", - "jsxBracketSameLine": true, - "semi": true, - "singleQuote": false -} diff --git a/jest.config.js b/jest.config.js index a0021104..f8f621c8 100644 --- a/jest.config.js +++ b/jest.config.js @@ -12,6 +12,5 @@ export default { transformIgnorePatterns: [ "/node_modules/(?!eventsource)/" ], - collectCoverageFrom: ["src/**/*.ts"], testPathIgnorePatterns: ["/node_modules/", "/dist/"], }; diff --git a/package.json b/package.json index dfaa10cd..86fd9d6d 100644 --- a/package.json +++ b/package.json @@ -41,7 +41,6 @@ "prepack": "npm run build:esm && npm run build:cjs", "lint": "eslint src/", "test": "jest", - "coverage": "jest --coverage", "start": "npm run server", "server": "tsx watch --clear-screen=false src/cli.ts server", "client": "tsx src/cli.ts client" diff --git a/src/inMemory.ts b/src/inMemory.ts index 65915baa..106a9e7e 100644 --- a/src/inMemory.ts +++ b/src/inMemory.ts @@ -12,7 +12,6 @@ export class InMemoryTransport implements Transport { onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; sessionId?: string; - user?: unknown; /** * Creates a pair of linked in-memory transports that can communicate with each other. One should be passed to a Client and one to a Server. diff --git a/src/server/mcp.test.ts b/src/server/mcp.test.ts index 08518b20..2e91a568 100644 --- a/src/server/mcp.test.ts +++ b/src/server/mcp.test.ts @@ -1,8 +1,7 @@ -import { McpServer, ToolCallback } from "./mcp.js"; +import { McpServer } from "./mcp.js"; import { Client } from "../client/index.js"; import { InMemoryTransport } from "../inMemory.js"; -import { z, ZodRawShape } from "zod"; -import { zodToJsonSchema } from "zod-to-json-schema"; +import { z } from "zod"; import { ListToolsResultSchema, CallToolResultSchema, @@ -12,16 +11,10 @@ import { ListPromptsResultSchema, GetPromptResultSchema, CompleteResultSchema, - CallToolRequestSchema, - CallToolRequest, - ListToolsRequestSchema, - ListToolsResult, - Tool, } from "../types.js"; import { ResourceTemplate } from "./mcp.js"; import { completable } from "./completable.js"; import { UriTemplate } from "../shared/uriTemplate.js"; -import { RequestHandlerExtra } from "../shared/protocol.js"; describe("McpServer", () => { test("should expose underlying Server instance", () => { @@ -325,7 +318,7 @@ describe("tool()", () => { // This should succeed mcpServer.tool("tool1", () => ({ content: [] })); - + // This should also succeed and not throw about request handlers mcpServer.tool("tool2", () => ({ content: [] })); }); @@ -361,8 +354,7 @@ describe("tool()", () => { }; }); - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); + const [clientTransport, serverTransport] = InMemoryTransport.createLinkedPair(); // Set a test sessionId on the server transport serverTransport.sessionId = "test-session-123"; @@ -823,7 +815,7 @@ describe("resource()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.resource("resource2", "test://resource2", async () => ({ contents: [ @@ -1329,7 +1321,7 @@ describe("prompt()", () => { }, ], })); - + // This should also succeed and not throw about request handlers mcpServer.prompt("prompt2", async () => ({ messages: [ @@ -1351,17 +1343,19 @@ describe("prompt()", () => { }); // This should succeed - mcpServer.prompt("echo", { message: z.string() }, ({ message }) => ({ - messages: [ - { + mcpServer.prompt( + "echo", + { message: z.string() }, + ({ message }) => ({ + messages: [{ role: "user", content: { type: "text", - text: `Please process this message: ${message}`, - }, - }, - ], - })); + text: `Please process this message: ${message}` + } + }] + }) + ); }); test("should allow registering both resources and prompts with completion handlers", () => { @@ -1394,16 +1388,14 @@ describe("prompt()", () => { "echo", { message: completable(z.string(), () => ["hello", "world"]) }, ({ message }) => ({ - messages: [ - { - role: "user", - content: { - type: "text", - text: `Please process this message: ${message}`, - }, - }, - ], - }), + messages: [{ + role: "user", + content: { + type: "text", + text: `Please process this message: ${message}` + } + }] + }) ); }); @@ -1590,351 +1582,3 @@ describe("prompt()", () => { expect(result.completion.total).toBe(1); }); }); - -describe("McpServer with Auth Extension", () => { - type SessionUser = { - role: string; - [key: string]: unknown; - }; - - type AccessPolicy = { - allow?: { - roles?: string[]; - }; - deny?: { - roles?: string[]; - }; - }; - - type RegisteredToolWithAuth = { - description?: string; - inputSchema?: z.ZodObject; - callback: ToolCallback; - accessPolicy?: AccessPolicy; - }; - - // Just a simple extension to McpServer that adds support for access policies on tools - class McpServerWithAuth extends McpServer { - protected override _registeredTools: { - [name: string]: RegisteredToolWithAuth; - } = {}; - checkPermissions(user?: SessionUser, policy?: AccessPolicy): boolean { - if (!policy) { - return true; - } - - if (!user) { - return false; - } - - // Check deny rules first - if (policy.deny) { - // Check denied roles - if (policy.deny.roles?.includes(user.role)) { - return false; - } - } - - // Check allow rules - if (policy.allow) { - let isAllowed = false; - - // If no allow rules are specified, default to allowed - if (!policy.allow.roles) { - isAllowed = true; - } else { - // Check allowed roles - if (policy.allow.roles?.includes(user.role)) { - isAllowed = true; - } - } - - return isAllowed; - } - - // If no rules specified, default to allowed - return true; - } - - override tool( - name: string, - cb: ToolCallback, - accessPolicy?: AccessPolicy, - ): void; - override tool( - name: string, - description: string, - cb: ToolCallback, - accessPolicy?: AccessPolicy, - ): void; - override tool( - name: string, - paramsSchema: Args, - cb: ToolCallback, - accessPolicy?: AccessPolicy, - ): void; - override tool( - name: string, - description: string, - paramsSchema: Args, - cb: ToolCallback, - accessPolicy?: AccessPolicy, - ): void; - override tool(name: string, ...rest: unknown[]): void { - let description: string | undefined; - let paramsSchema: ZodRawShape | undefined; - let accessPolicy: AccessPolicy | undefined; - let cb: ToolCallback; - - // Parse arguments based on their types - if (typeof rest[0] === "function") { - // Case: tool(name, cb, accessPolicy?) - cb = rest[0] as ToolCallback; - accessPolicy = rest[1] as AccessPolicy | undefined; - } else if (typeof rest[0] === "string") { - // Cases with description - description = rest[0]; - if (typeof rest[1] === "function") { - // Case: tool(name, description, cb, accessPolicy?) - cb = rest[1] as ToolCallback; - accessPolicy = rest[2] as AccessPolicy | undefined; - } else { - // Case: tool(name, description, paramsSchema, cb, accessPolicy?) - paramsSchema = rest[1] as ZodRawShape; - cb = rest[2] as ToolCallback; - accessPolicy = rest[3] as AccessPolicy | undefined; - } - } else { - // Case: tool(name, paramsSchema, cb, accessPolicy?) - paramsSchema = rest[0] as ZodRawShape; - cb = rest[1] as ToolCallback; - accessPolicy = rest[2] as AccessPolicy | undefined; - } - - // Register with base class - const args: unknown[] = [name]; - if (description) args.push(description); - if (paramsSchema) args.push(paramsSchema); - args.push(cb); - - // Set up request handlers if not already initialized - if (!this._toolHandlersInitialized) { - this.server.assertCanSetRequestHandler( - CallToolRequestSchema.shape.method.value, - ); - this.server.assertCanSetRequestHandler( - ListToolsRequestSchema.shape.method.value, - ); - this.server.registerCapabilities({ tools: {} }); - - // Add ListToolsRequestSchema handler - this.server.setRequestHandler( - ListToolsRequestSchema, - (request, extra): ListToolsResult => { - const user = extra.user as SessionUser | undefined; - - // Filter tools based on permissions - const accessibleTools = Object.entries(this._registeredTools) - .filter(([_, tool]) => - this.checkPermissions(user, tool.accessPolicy), - ) - .map( - ([name, tool]): Tool => ({ - name, - description: tool.description, - inputSchema: tool.inputSchema - ? (zodToJsonSchema(tool.inputSchema, { - strictUnions: true, - }) as Tool["inputSchema"]) - : { type: "object" }, - }), - ); - - return { tools: accessibleTools }; - }, - ); - - this.server.setRequestHandler( - CallToolRequestSchema, - async (request: CallToolRequest, extra: RequestHandlerExtra) => { - const tool = this._registeredTools[request.params.name]; - if (!tool) { - throw new Error(`Tool ${request.params.name} not found`); - } - - if ( - !this.checkPermissions( - extra.user as SessionUser, - tool.accessPolicy, - ) - ) { - throw new Error(`Access denied for tool: ${request.params.name}`); - } - - if (tool.inputSchema) { - const parseResult = await tool.inputSchema.safeParseAsync( - request.params.arguments, - ); - if (!parseResult.success) { - throw new Error( - `Invalid arguments for tool ${request.params.name}: ${parseResult.error.message}`, - ); - } - - const args = parseResult.data; - const cb = tool.callback as ToolCallback; - return await Promise.resolve(cb(args, extra)); - } else { - const cb = tool.callback as ToolCallback; - return await Promise.resolve(cb(extra)); - } - }, - ); - this._toolHandlersInitialized = true; - } - - McpServer.prototype.tool.apply( - this, - args as Parameters, - ); - this._registeredTools[name].accessPolicy = accessPolicy; - } - } - - const mcpServer = new McpServerWithAuth({ - name: "test server with auth", - version: "1.0", - }); - const client = new Client({ - name: "test client", - version: "1.0", - }); - - mcpServer.tool("public-tool", async () => ({ - content: [ - { - type: "text", - text: "Public tool response", - }, - ], - })); - - mcpServer.tool( - "protected-tool", - async () => ({ - content: [ - { - type: "text", - text: "Protected tool response", - }, - ], - }), - { - allow: { - roles: ["admin"], - }, - }, - ); - - test("should public tools work with list and call when unauthenticated", async () => { - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Public tool should be accessible - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("public-tool"); - const response = await client.request( - { - method: "tools/call", - params: { - name: "public-tool", - arguments: {}, - }, - }, - CallToolResultSchema, - ); - expect(response.content).toEqual([ - { - type: "text", - text: "Public tool response", - }, - ]); - }); - - test("should public tools work with list and call when unauthorized", async () => { - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - // Protected tool should be inaccessible when authenticated as a non-admin user - serverTransport.user = { role: "member" }; - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - expect(result.tools).toHaveLength(1); - expect(result.tools[0].name).toBe("public-tool"); - const response = await client.request( - { - method: "tools/call", - params: { - name: "public-tool", - arguments: {}, - }, - }, - CallToolResultSchema, - ); - expect(response.content).toEqual([ - { - type: "text", - text: "Public tool response", - }, - ]); - }); - - test("should protected tools work with list and call when authorized", async () => { - const [clientTransport, serverTransport] = - InMemoryTransport.createLinkedPair(); - await Promise.all([ - client.connect(clientTransport), - mcpServer.server.connect(serverTransport), - ]); - - serverTransport.user = { role: "admin" }; - const result = await client.request( - { method: "tools/list" }, - ListToolsResultSchema, - ); - expect(result.tools).toHaveLength(2); - expect(result.tools[0].name).toBe("public-tool"); - expect(result.tools[1].name).toBe("protected-tool"); - - const response = await client.request( - { - method: "tools/call", - params: { - name: "protected-tool", - arguments: {}, - }, - }, - CallToolResultSchema, - ); - expect(response.content).toEqual([ - { - type: "text", - text: "Protected tool response", - }, - ]); - }); -}); diff --git a/src/server/mcp.ts b/src/server/mcp.ts index ed93256f..8f4a909c 100644 --- a/src/server/mcp.ts +++ b/src/server/mcp.ts @@ -54,17 +54,12 @@ export class McpServer { */ public readonly server: Server; - protected _registeredResources: { [uri: string]: RegisteredResource } = {}; - protected _registeredResourceTemplates: { + private _registeredResources: { [uri: string]: RegisteredResource } = {}; + private _registeredResourceTemplates: { [name: string]: RegisteredResourceTemplate; } = {}; - protected _registeredTools: { [name: string]: RegisteredTool } = {}; - protected _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; - - protected _toolHandlersInitialized = false; - protected _completionHandlerInitialized = false; - protected _resourceHandlersInitialized = false; - protected _promptHandlersInitialized = false; + private _registeredTools: { [name: string]: RegisteredTool } = {}; + private _registeredPrompts: { [name: string]: RegisteredPrompt } = {}; constructor(serverInfo: Implementation, options?: ServerOptions) { this.server = new Server(serverInfo, options); @@ -86,11 +81,13 @@ export class McpServer { await this.server.close(); } + private _toolHandlersInitialized = false; + private setToolRequestHandlers() { if (this._toolHandlersInitialized) { return; } - + this.server.assertCanSetRequestHandler( ListToolsRequestSchema.shape.method.value, ); @@ -180,6 +177,8 @@ export class McpServer { this._toolHandlersInitialized = true; } + private _completionHandlerInitialized = false; + private setCompletionRequestHandler() { if (this._completionHandlerInitialized) { return; @@ -268,6 +267,8 @@ export class McpServer { return createCompletionResult(suggestions); } + private _resourceHandlersInitialized = false; + private setResourceRequestHandlers() { if (this._resourceHandlersInitialized) { return; @@ -365,10 +366,12 @@ export class McpServer { ); this.setCompletionRequestHandler(); - + this._resourceHandlersInitialized = true; } + private _promptHandlersInitialized = false; + private setPromptRequestHandlers() { if (this._promptHandlersInitialized) { return; @@ -435,7 +438,7 @@ export class McpServer { ); this.setCompletionRequestHandler(); - + this._promptHandlersInitialized = true; } @@ -767,7 +770,7 @@ type RegisteredPrompt = { callback: PromptCallback; }; -export function promptArgumentsFromSchema( +function promptArgumentsFromSchema( schema: ZodObject, ): PromptArgument[] { return Object.entries(schema.shape).map( diff --git a/src/server/sse.ts b/src/server/sse.ts index 73603021..84c1cbb9 100644 --- a/src/server/sse.ts +++ b/src/server/sse.ts @@ -19,7 +19,6 @@ export class SSEServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; - user?: unknown; /** * Creates a new SSE server transport, which will direct the client to POST messages to the relative or absolute URL identified by `_endpoint`. diff --git a/src/server/stdio.ts b/src/server/stdio.ts index 7c3b21c6..30c80012 100644 --- a/src/server/stdio.ts +++ b/src/server/stdio.ts @@ -21,7 +21,6 @@ export class StdioServerTransport implements Transport { onclose?: () => void; onerror?: (error: Error) => void; onmessage?: (message: JSONRPCMessage) => void; - user?: unknown; // Arrow functions to bind `this` properly, while maintaining function identity. _ondata = (chunk: Buffer) => { @@ -74,7 +73,7 @@ export class StdioServerTransport implements Transport { // This prevents interfering with other parts of the application that might be using stdin this._stdin.pause(); } - + // Clear the buffer and notify closure this._readBuffer.clear(); this.onclose?.(); diff --git a/src/shared/protocol.ts b/src/shared/protocol.ts index 5540d4fe..a6e47184 100644 --- a/src/shared/protocol.ts +++ b/src/shared/protocol.ts @@ -93,11 +93,6 @@ export type RequestHandlerExtra = { * The session ID from the transport, if available. */ sessionId?: string; - - /** - * The authenticated user, if available. - */ - user?: unknown; }; /** @@ -324,7 +319,6 @@ export abstract class Protocol< const extra: RequestHandlerExtra = { signal: abortController.signal, sessionId: this._transport?.sessionId, - user: this._transport?.user, }; // Starting with Promise.resolve() puts any synchronous errors into the monad as well. @@ -370,7 +364,7 @@ export abstract class Protocol< private _onprogress(notification: ProgressNotification): void { const { progressToken, ...params } = notification.params; const messageId = Number(progressToken); - + const handler = this._progressHandlers.get(messageId); if (!handler) { this._onerror(new Error(`Received a progress notification for an unknown token: ${JSON.stringify(notification)}`)); diff --git a/src/shared/transport.ts b/src/shared/transport.ts index 88cea7f2..b80e2a51 100644 --- a/src/shared/transport.ts +++ b/src/shared/transport.ts @@ -46,9 +46,4 @@ export interface Transport { * The session ID generated for this connection. */ sessionId?: string; - - /** - * The authenticated user for this transport session. - */ - user?: unknown; }