diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 34671c105..f4cd57be6 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -13,6 +13,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -316,33 +317,36 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } else { // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - - asyncContext.addListener(new jakarta.servlet.AsyncListener() { - @Override - public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection timed out for session: {}", sessionId); - listeningStream.close(); - } + session.listeningStream(sessionTransport) + .doOnNext(serverSessionStream -> asyncContext.addListener(new jakarta.servlet.AsyncListener() { + @Override + public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection completed for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection timed out for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onError(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection error for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { + // No action needed + } + })) + .doOnError(error -> { + logger.error("Failed to create listening stream", error); + }) + .subscribe(serverSessionStream -> logger.debug("Listening stream created successfully")); - @Override - public void onError(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection error for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { - // No action needed - } - }); } } catch (Exception e) { diff --git a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index ec03dd424..bf50d8c0e 100644 --- a/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp-core/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -138,10 +138,16 @@ public Mono delete() { * @param transport The dedicated SSE transport stream * @return a stream representation */ - public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + public Mono listeningStream(McpStreamableServerTransport transport) { McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); - this.listeningStreamRef.set(listeningStream); - return listeningStream; + McpLoggableSession oldStream = this.listeningStreamRef.getAndSet(listeningStream); + if (oldStream != null && !(oldStream instanceof MissingMcpTransportSession)) { + logger.debug( + "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", + this.id); + return oldStream.closeGracefully().thenReturn(listeningStream); + } + return Mono.just(listeningStream); } // TODO: keep track of history by keeping a map from eventId to stream and then diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index 144a3ce02..38d587090 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -201,9 +201,16 @@ private Mono handleGet(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( sink); - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - sink.onDispose(listeningStream::close); + session.listeningStream(sessionTransport) + .doOnNext(serverSessionStream -> sink + .onDispose(() -> serverSessionStream.closeGracefully().subscribe(v -> { + }, error -> logger.warn("Failed to close listening stream gracefully", error)))) + .doOnError(error -> { + logger.error("Failed to create listening stream", error); + sink.error(error); + }) + .subscribe(serverSessionStream -> logger.debug("Listening stream created successfully"), + sink::error); // TODO Clarify why the outer context is not present in the // Flux.create sink? }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)), ServerSentEvent.class); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index d85046a67..f0754079a 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -289,13 +289,15 @@ private ServerResponse handleGet(ServerRequest request) { } else { // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + Mono listeningStream = session .listeningStream(sessionTransport); - - sseBuilder.onComplete(() -> { + listeningStream.subscribe(serverSessionStream -> sseBuilder.onComplete(() -> { logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - }); + serverSessionStream.close(); + }), error -> { + sseBuilder.error(error); + logger.error("Failed to create listening stream", error); + }, () -> logger.debug("Listening stream created successfully")); } }, Duration.ZERO); } diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index cb7b4a2a0..59c7ac8ed 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -85,17 +85,24 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(MESSAGE_ENDPOINT) - .build())); + var httpClientTransport = HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .openConnectionOnStartup(true) + .build(); + + clientTransportBuilders.put("httpclient", httpClientTransport); + clientBuilders.put("httpclient", + McpClient.sync(httpClientTransport) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + var webClientTransport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .openConnectionOnStartup(true) + .build(); + clientTransportBuilders.put("webflux", webClientTransport); + + clientBuilders.put("webflux", McpClient.sync(webClientTransport)); // Get the transport from Spring context this.mcpServerTransportProvider = tomcatServer.appContext() diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index d4ccbc173..29d4a4850 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -18,7 +18,7 @@ - + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 84bd271a5..3eaa13cee 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -4,6 +4,9 @@ package io.modelcontextprotocol; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -27,6 +30,7 @@ import io.modelcontextprotocol.server.McpServerFeatures; import io.modelcontextprotocol.server.McpSyncServer; import io.modelcontextprotocol.server.McpSyncServerExchange; +import io.modelcontextprotocol.spec.McpClientTransport; import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpSchema.CallToolResult; @@ -68,6 +72,8 @@ public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + protected ConcurrentHashMap clientTransportBuilders = new ConcurrentHashMap<>(); + abstract protected void prepareClients(int port, String mcpEndpoint); abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); @@ -1015,6 +1021,32 @@ void testInitialize(String clientType) { } } + @ParameterizedTest(name = "{0} : {displayName} ") + @MethodSource("clientsForTesting") + void testListeningStreamWillClosedWhenNew(String clientType) throws IOException { + var clientTransport = clientTransportBuilders.get(clientType); + if (clientTransport == null) { + return; + } + PrintStream originalOut = System.out; + ByteArrayOutputStream capturedOutput = new ByteArrayOutputStream(); + System.setOut(new PrintStream(capturedOutput)); + + var clientBuilder = clientBuilders.get(clientType); + var mcpServer = prepareSyncServerBuilder().build(); + var mcpClient = clientBuilder.build(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + clientTransport.connect(message -> Mono.empty()).subscribe(); + await().atMost(Duration.ofSeconds(1)).untilAsserted(() -> { + assertThat(capturedOutput.toString().contains("Listening stream already exists for this session")).isTrue(); + }); + System.setOut(originalOut); + capturedOutput.close(); + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Logging Tests // --------------------------------------- @@ -1438,7 +1470,7 @@ void testStructuredOutputOfObjectArrayValidationSuccess(String clientType) { "type", "object", "properties", Map.of( "name", Map.of("type", "string"), - "age", Map.of("type", "number")), + "age", Map.of("type", "number")), "required", List.of("name", "age"))); // @formatter:on Tool calculatorTool = Tool.builder()