Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 56 additions & 13 deletions mcp/streamable.go
Original file line number Diff line number Diff line change
Expand Up @@ -1389,6 +1389,20 @@ type StreamableClientTransport struct {
// It defaults to 5. To disable retries, use a negative number.
MaxRetries int

// DisableListening disables receiving server-to-client notifications when no request is in flight.
//
// By default, the client establishes a standalone long-live GET HTTP connection to the server
// to receive server-initiated messages (like ToolListChangedNotification).
// https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#listening-for-messages-from-the-server
// NOTICE: Even if continuous listening is enabled, the server may not support this feature.
//
// If false (default), the client will establish the standalone SSE stream.
// If true, the client will not establish the standalone SSE stream and will only receive
// responses to its own requests.
//
// Defaults to false to maintain backward compatibility with existing behavior.
DisableListening bool

// TODO(rfindley): propose exporting these.
// If strict is set, the transport is in 'strict mode', where any violation
// of the MCP spec causes a failure.
Expand Down Expand Up @@ -1416,6 +1430,28 @@ var (
reconnectInitialDelay = 1 * time.Second
)

// WithDisableListening disables receiving server-to-client notifications when no request is in flight.
//
// By default, the client establishes a standalone long-live GET HTTP connection to the server
// to receive server-initiated messages. This function disables that behavior.
//
// If you want to disable continuous listening, you can either:
//
// transport := &mcp.StreamableClientTransport{
// Endpoint: "http://localhost:8080/mcp",
// DisableListening: true,
// }
//
// Or use this convenience function:
//
// transport := &mcp.StreamableClientTransport{
// Endpoint: "http://localhost:8080/mcp",
// }
// mcp.WithDisableListening(transport)
func WithDisableListening(transport *StreamableClientTransport) {
transport.DisableListening = true
}

// Connect implements the [Transport] interface.
//
// The resulting [Connection] writes messages via POST requests to the
Expand Down Expand Up @@ -1453,16 +1489,17 @@ func (t *StreamableClientTransport) Connect(ctx context.Context) (Connection, er
// middleware), yet only cancel the standalone stream when the connection is closed.
connCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
conn := &streamableClientConn{
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: ensureLogger(t.logger), // must be non-nil for safe logging
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
url: t.Endpoint,
client: client,
incoming: make(chan jsonrpc.Message, 10),
done: make(chan struct{}),
maxRetries: maxRetries,
strict: t.strict,
logger: ensureLogger(t.logger), // must be non-nil for safe logging
ctx: connCtx,
cancel: cancel,
failed: make(chan struct{}),
disableListening: t.DisableListening,
}
return conn, nil
}
Expand All @@ -1477,6 +1514,10 @@ type streamableClientConn struct {
strict bool // from [StreamableClientTransport.strict]
logger *slog.Logger // from [StreamableClientTransport.logger]

// disableListening controls whether to disable the standalone SSE stream
// for receiving server-to-client notifications when no request is in flight.
disableListening bool // from [StreamableClientTransport.DisableListening]

// Guard calls to Close, as it may be called multiple times.
closeOnce sync.Once
closeErr error
Expand Down Expand Up @@ -1518,7 +1559,7 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
c.mu.Unlock()

// Start the standalone SSE stream as soon as we have the initialized
// result.
// result, if continuous listening is enabled.
//
// § 2.2: The client MAY issue an HTTP GET to the MCP endpoint. This can be
// used to open an SSE stream, allowing the server to communicate to the
Expand All @@ -1528,9 +1569,11 @@ func (c *streamableClientConn) sessionUpdated(state clientSessionState) {
// initialized, we don't know whether the server requires a sessionID.
//
// § 2.5: A server using the Streamable HTTP transport MAY assign a session
// ID at initialization time, by including it in an Mcp-Session-Id header
// ID at initialization time, by including it in a Mcp-Session-Id header
// on the HTTP response containing the InitializeResult.
c.connectStandaloneSSE()
if !c.disableListening {
c.connectStandaloneSSE()
}
}

func (c *streamableClientConn) connectStandaloneSSE() {
Expand Down
118 changes: 118 additions & 0 deletions mcp/streamable_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -693,3 +693,121 @@ func TestStreamableClientTransientErrors(t *testing.T) {
})
}
}

func TestStreamableClientDisableListening(t *testing.T) {
ctx := context.Background()

tests := []struct {
name string
disableListening bool
expectGETRequest bool
}{
{
name: "default behavior (listening enabled)",
disableListening: false,
expectGETRequest: true,
},
{
name: "listening disabled",
disableListening: true,
expectGETRequest: false,
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
getRequestKey := streamableRequestKey{"GET", "123", "", ""}

fake := &fakeStreamableServer{
t: t,
responses: fakeResponses{
{"POST", "", methodInitialize, ""}: {
header: header{
"Content-Type": "application/json",
sessionIDHeader: "123",
},
body: jsonBody(t, initResp),
},
{"POST", "123", notificationInitialized, ""}: {
status: http.StatusAccepted,
wantProtocolVersion: latestProtocolVersion,
},
getRequestKey: {
header: header{
"Content-Type": "text/event-stream",
},
wantProtocolVersion: latestProtocolVersion,
optional: !test.expectGETRequest,
},
{"DELETE", "123", "", ""}: {
optional: true,
},
},
}

httpServer := httptest.NewServer(fake)
defer httpServer.Close()

transport := &StreamableClientTransport{
Endpoint: httpServer.URL,
DisableListening: test.disableListening,
}
client := NewClient(testImpl, nil)
session, err := client.Connect(ctx, transport, nil)
if err != nil {
t.Fatalf("client.Connect() failed: %v", err)
}

// Give some time for the standalone SSE connection to be established (if enabled)
time.Sleep(100 * time.Millisecond)

// Verify the connection state
streamableConn, ok := session.mcpConn.(*streamableClientConn)
if !ok {
t.Fatalf("Expected *streamableClientConn, got %T", session.mcpConn)
}

if got, want := streamableConn.disableListening, test.disableListening; got != want {
t.Errorf("disableListening field: got %v, want %v", got, want)
}

// Clean up
if err := session.Close(); err != nil {
t.Errorf("closing session: %v", err)
}

// Check if GET request was received
fake.calledMu.Lock()
getRequestReceived := false
if fake.called != nil {
getRequestReceived = fake.called[getRequestKey]
}
fake.calledMu.Unlock()

if got, want := getRequestReceived, test.expectGETRequest; got != want {
t.Errorf("GET request received: got %v, want %v", got, want)
}

// If we expected a GET request, verify it was actually received
if test.expectGETRequest {
if missing := fake.missingRequests(); len(missing) > 0 {
// Filter out optional requests
var requiredMissing []streamableRequestKey
for _, key := range missing {
if resp, ok := fake.responses[key]; ok && !resp.optional {
requiredMissing = append(requiredMissing, key)
}
}
if len(requiredMissing) > 0 {
t.Errorf("did not receive expected requests: %v", requiredMissing)
}
}
} else {
// If we didn't expect a GET request, verify it wasn't sent
if getRequestReceived {
t.Error("GET request was sent unexpectedly when DisableListening is true")
}
}
})
}
}