From f45b05e3cf9d0e0f779da02ff33903899fba429e Mon Sep 17 00:00:00 2001 From: zhaoxin Date: Thu, 18 Dec 2025 11:17:14 +0800 Subject: [PATCH] mcp: add DisableListening option to StreamableClientTransport --- mcp/streamable.go | 69 ++++++++++++++++---- mcp/streamable_client_test.go | 118 ++++++++++++++++++++++++++++++++++ 2 files changed, 174 insertions(+), 13 deletions(-) diff --git a/mcp/streamable.go b/mcp/streamable.go index b4b2fa31..9bed4d96 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -1389,6 +1389,20 @@ type StreamableClientTransport struct { // It defaults to 5. To disable retries, use a negative number. MaxRetries int + // DisableListening disables receiving server-to-client notifications when no request is in flight. + // + // By default, the client establishes a standalone long-live GET HTTP connection to the server + // to receive server-initiated messages (like ToolListChangedNotification). + // https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server + // NOTICE: Even if continuous listening is enabled, the server may not support this feature. + // + // If false (default), the client will establish the standalone SSE stream. + // If true, the client will not establish the standalone SSE stream and will only receive + // responses to its own requests. + // + // Defaults to false to maintain backward compatibility with existing behavior. + DisableListening bool + // TODO(rfindley): propose exporting these. // If strict is set, the transport is in 'strict mode', where any violation // of the MCP spec causes a failure. @@ -1416,6 +1430,28 @@ var ( reconnectInitialDelay = 1 * time.Second ) +// WithDisableListening disables receiving server-to-client notifications when no request is in flight. +// +// By default, the client establishes a standalone long-live GET HTTP connection to the server +// to receive server-initiated messages. This function disables that behavior. +// +// If you want to disable continuous listening, you can either: +// +// transport := &mcp.StreamableClientTransport{ +// Endpoint: "http://localhost:8080/mcp", +// DisableListening: true, +// } +// +// Or use this convenience function: +// +// transport := &mcp.StreamableClientTransport{ +// Endpoint: "http://localhost:8080/mcp", +// } +// mcp.WithDisableListening(transport) +func WithDisableListening(transport *StreamableClientTransport) { + transport.DisableListening = true +} + // Connect implements the [Transport] interface. // // The resulting [Connection] writes messages via POST requests to the @@ -1453,16 +1489,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er // middleware), yet only cancel the standalone stream when the connection is closed. connCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) conn := &streamableClientConn{ - url: t.Endpoint, - client: client, - incoming: make(chan jsonrpc.Message, 10), - done: make(chan struct{}), - maxRetries: maxRetries, - strict: t.strict, - logger: ensureLogger(t.logger), // must be non-nil for safe logging - ctx: connCtx, - cancel: cancel, - failed: make(chan struct{}), + url: t.Endpoint, + client: client, + incoming: make(chan jsonrpc.Message, 10), + done: make(chan struct{}), + maxRetries: maxRetries, + strict: t.strict, + logger: ensureLogger(t.logger), // must be non-nil for safe logging + ctx: connCtx, + cancel: cancel, + failed: make(chan struct{}), + disableListening: t.DisableListening, } return conn, nil } @@ -1477,6 +1514,10 @@ type streamableClientConn struct { strict bool // from [StreamableClientTransport.strict] logger *slog.Logger // from [StreamableClientTransport.logger] + // disableListening controls whether to disable the standalone SSE stream + // for receiving server-to-client notifications when no request is in flight. + disableListening bool // from [StreamableClientTransport.DisableListening] + // Guard calls to Close, as it may be called multiple times. closeOnce sync.Once closeErr error @@ -1518,7 +1559,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { c.mu.Unlock() // Start the standalone SSE stream as soon as we have the initialized - // result. + // result, if continuous listening is enabled. // // § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be // used to open an SSE stream, allowing the server to communicate to the @@ -1528,9 +1569,11 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) { // initialized, we don't know whether the server requires a sessionID. // // § 2.5: A server using the Streamable HTTP transport MAY assign a session - // ID at initialization time, by including it in an Mcp-Session-Id header + // ID at initialization time, by including it in a Mcp-Session-Id header // on the HTTP response containing the InitializeResult. - c.connectStandaloneSSE() + if !c.disableListening { + c.connectStandaloneSSE() + } } func (c *streamableClientConn) connectStandaloneSSE() { diff --git a/mcp/streamable_client_test.go b/mcp/streamable_client_test.go index e2923325..1fbcb3d7 100644 --- a/mcp/streamable_client_test.go +++ b/mcp/streamable_client_test.go @@ -693,3 +693,121 @@ func TestStreamableClientTransientErrors(t *testing.T) { }) } } + +func TestStreamableClientDisableListening(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + disableListening bool + expectGETRequest bool + }{ + { + name: "default behavior (listening enabled)", + disableListening: false, + expectGETRequest: true, + }, + { + name: "listening disabled", + disableListening: true, + expectGETRequest: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + getRequestKey := streamableRequestKey{"GET", "123", "", ""} + + fake := &fakeStreamableServer{ + t: t, + responses: fakeResponses{ + {"POST", "", methodInitialize, ""}: { + header: header{ + "Content-Type": "application/json", + sessionIDHeader: "123", + }, + body: jsonBody(t, initResp), + }, + {"POST", "123", notificationInitialized, ""}: { + status: http.StatusAccepted, + wantProtocolVersion: latestProtocolVersion, + }, + getRequestKey: { + header: header{ + "Content-Type": "text/event-stream", + }, + wantProtocolVersion: latestProtocolVersion, + optional: !test.expectGETRequest, + }, + {"DELETE", "123", "", ""}: { + optional: true, + }, + }, + } + + httpServer := httptest.NewServer(fake) + defer httpServer.Close() + + transport := &StreamableClientTransport{ + Endpoint: httpServer.URL, + DisableListening: test.disableListening, + } + client := NewClient(testImpl, nil) + session, err := client.Connect(ctx, transport, nil) + if err != nil { + t.Fatalf("client.Connect() failed: %v", err) + } + + // Give some time for the standalone SSE connection to be established (if enabled) + time.Sleep(100 * time.Millisecond) + + // Verify the connection state + streamableConn, ok := session.mcpConn.(*streamableClientConn) + if !ok { + t.Fatalf("Expected *streamableClientConn, got %T", session.mcpConn) + } + + if got, want := streamableConn.disableListening, test.disableListening; got != want { + t.Errorf("disableListening field: got %v, want %v", got, want) + } + + // Clean up + if err := session.Close(); err != nil { + t.Errorf("closing session: %v", err) + } + + // Check if GET request was received + fake.calledMu.Lock() + getRequestReceived := false + if fake.called != nil { + getRequestReceived = fake.called[getRequestKey] + } + fake.calledMu.Unlock() + + if got, want := getRequestReceived, test.expectGETRequest; got != want { + t.Errorf("GET request received: got %v, want %v", got, want) + } + + // If we expected a GET request, verify it was actually received + if test.expectGETRequest { + if missing := fake.missingRequests(); len(missing) > 0 { + // Filter out optional requests + var requiredMissing []streamableRequestKey + for _, key := range missing { + if resp, ok := fake.responses[key]; ok && !resp.optional { + requiredMissing = append(requiredMissing, key) + } + } + if len(requiredMissing) > 0 { + t.Errorf("did not receive expected requests: %v", requiredMissing) + } + } + } else { + // If we didn't expect a GET request, verify it wasn't sent + if getRequestReceived { + t.Error("GET request was sent unexpectedly when DisableListening is true") + } + } + }) + } +}