From 3dac3075f898ef23cbcdeb2de1f70008555e8ab7 Mon Sep 17 00:00:00 2001 From: taobaorun Date: Thu, 28 Aug 2025 13:15:41 +0800 Subject: [PATCH 1/5] avoid streamable listening sse duplicate creation --- .../WebFluxStreamableServerTransportProvider.java | 12 ++++++++++-- .../WebMvcStreamableServerTransportProvider.java | 12 +++++++++--- ...HttpServletStreamableServerTransportProvider.java | 11 ++++++++++- .../spec/McpStreamableServerSession.java | 4 ++++ 4 files changed, 33 insertions(+), 6 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index f3f6c2c33..b67efad64 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -10,8 +10,10 @@ import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; +import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; @@ -187,12 +189,18 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.notFound().build(); } - if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + McpLoggableSession listenedStream = session.getListeningStream(); + boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID); + if (replayRequest) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(session.replay(lastId), ServerSentEvent.class); } + if (listenedStream instanceof McpStreamableServerSessionStream) { + logger.debug("Listening stream for session: {} exists.", sessionId); + return ServerResponse.ok().build(); + } return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) @@ -484,4 +492,4 @@ public WebFluxStreamableServerTransportProvider build() { } -} \ No newline at end of file +} diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index fa51a0130..f1f097fea 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -10,6 +10,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpLoggableSession; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -252,7 +254,12 @@ private ServerResponse handleGet(ServerRequest request) { } logger.debug("Handling GET request for session: {}", sessionId); - + McpLoggableSession listenedStream = session.getListeningStream(); + boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID); + if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) { + logger.debug("Listening stream for session: {} exists.", sessionId); + return ServerResponse.ok().build(); + } try { return ServerResponse.sse(sseBuilder -> { sseBuilder.onTimeout(() -> { @@ -263,9 +270,8 @@ private ServerResponse handleGet(ServerRequest request) { sessionId, sseBuilder); // Check if this is a replay request - if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { + if (replayRequest) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); - try { session.replay(lastId) .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 8b95ec607..564e762b2 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -13,6 +13,8 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpLoggableSession; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -273,6 +275,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } logger.debug("Handling GET request for session: {}", sessionId); + McpLoggableSession listenedStream = session.getListeningStream(); + boolean replayRequest = request.getHeader(HttpHeaders.LAST_EVENT_ID) != null; + if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) { + logger.debug("Listening stream for session: {} exists.", sessionId); + response.setStatus(HttpServletResponse.SC_OK); + return; + } McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); @@ -290,7 +299,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) sessionId, asyncContext, response.getWriter()); // Check if this is a replay request - if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { + if (replayRequest) { String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); try { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index ef7967c1e..1c3a1a702 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -142,6 +142,10 @@ public McpStreamableServerSessionStream listeningStream(McpStreamableServerTrans return listeningStream; } + public McpLoggableSession getListeningStream() { + return this.listeningStreamRef.get(); + } + // TODO: keep track of history by keeping a map from eventId to stream and then // iterate over the events using the lastEventId public Flux replay(Object lastEventId) { From c7cbe98208303625c3f1a72b348d73d23ad0810b Mon Sep 17 00:00:00 2001 From: taobaorun Date: Sun, 31 Aug 2025 16:58:24 +0800 Subject: [PATCH 2/5] Listening stream already exists for this session and will be closed to make way for the new listening SSE stream --- ...ebFluxStreamableServerTransportProvider.java | 9 +++++---- ...WebMvcStreamableServerTransportProvider.java | 15 ++++++++------- ...ervletStreamableServerTransportProvider.java | 17 ++++++++--------- 3 files changed, 21 insertions(+), 20 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index b67efad64..0a9eb452b 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -190,16 +190,17 @@ private Mono handleGet(ServerRequest request) { } McpLoggableSession listenedStream = session.getListeningStream(); - boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID); - if (replayRequest) { + if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(session.replay(lastId), ServerSentEvent.class); } if (listenedStream instanceof McpStreamableServerSessionStream) { - logger.debug("Listening stream for session: {} exists.", sessionId); - return ServerResponse.ok().build(); + logger.debug( + "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", + sessionId); + listenedStream.close(); } return ServerResponse.ok() diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index f1f097fea..138b46e92 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -254,12 +254,6 @@ private ServerResponse handleGet(ServerRequest request) { } logger.debug("Handling GET request for session: {}", sessionId); - McpLoggableSession listenedStream = session.getListeningStream(); - boolean replayRequest = request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID); - if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) { - logger.debug("Listening stream for session: {} exists.", sessionId); - return ServerResponse.ok().build(); - } try { return ServerResponse.sse(sseBuilder -> { sseBuilder.onTimeout(() -> { @@ -270,7 +264,7 @@ private ServerResponse handleGet(ServerRequest request) { sessionId, sseBuilder); // Check if this is a replay request - if (replayRequest) { + if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); try { session.replay(lastId) @@ -294,6 +288,13 @@ private ServerResponse handleGet(ServerRequest request) { } } else { + McpLoggableSession listenedStream = session.getListeningStream(); + if (listenedStream instanceof McpStreamableServerSessionStream) { + logger.debug( + "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", + sessionId); + listenedStream.close(); + } // Establish new listening stream McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 564e762b2..bbc26a31c 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -275,14 +275,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } logger.debug("Handling GET request for session: {}", sessionId); - McpLoggableSession listenedStream = session.getListeningStream(); - boolean replayRequest = request.getHeader(HttpHeaders.LAST_EVENT_ID) != null; - if (!replayRequest && listenedStream instanceof McpStreamableServerSessionStream) { - logger.debug("Listening stream for session: {} exists.", sessionId); - response.setStatus(HttpServletResponse.SC_OK); - return; - } - McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); try { @@ -299,7 +291,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) sessionId, asyncContext, response.getWriter()); // Check if this is a replay request - if (replayRequest) { + if (request.getHeader(HttpHeaders.LAST_EVENT_ID) != null) { String lastId = request.getHeader(HttpHeaders.LAST_EVENT_ID); try { @@ -324,6 +316,13 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } } else { + McpLoggableSession listenedStream = session.getListeningStream(); + if (listenedStream instanceof McpStreamableServerSessionStream) { + logger.debug( + "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", + sessionId); + listenedStream.close(); + } // Establish new listening stream McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); From e87387b6aec9099090e4b96f09a3b406a476f2a8 Mon Sep 17 00:00:00 2001 From: taobaorun Date: Tue, 2 Sep 2025 01:02:23 +0800 Subject: [PATCH 3/5] Atomically close the existing listening stream and switch to the new one. --- ...FluxStreamableServerTransportProvider.java | 9 ----- ...bMvcStreamableServerTransportProvider.java | 11 ++---- ...vletStreamableServerTransportProvider.java | 10 +----- .../spec/McpStreamableServerSession.java | 34 +++++++++---------- 4 files changed, 20 insertions(+), 44 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index 0a9eb452b..b2b44f5cd 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -10,10 +10,8 @@ import io.modelcontextprotocol.server.McpTransportContextExtractor; import io.modelcontextprotocol.spec.HttpHeaders; import io.modelcontextprotocol.spec.McpError; -import io.modelcontextprotocol.spec.McpLoggableSession; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStreamableServerSession; -import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; @@ -189,19 +187,12 @@ private Mono handleGet(ServerRequest request) { return ServerResponse.notFound().build(); } - McpLoggableSession listenedStream = session.getListeningStream(); if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) .body(session.replay(lastId), ServerSentEvent.class); } - if (listenedStream instanceof McpStreamableServerSessionStream) { - logger.debug( - "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", - sessionId); - listenedStream.close(); - } return ServerResponse.ok() .contentType(MediaType.TEXT_EVENT_STREAM) diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index 138b46e92..fa51a0130 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -10,8 +10,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; -import io.modelcontextprotocol.spec.McpLoggableSession; -import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -254,6 +252,7 @@ private ServerResponse handleGet(ServerRequest request) { } logger.debug("Handling GET request for session: {}", sessionId); + try { return ServerResponse.sse(sseBuilder -> { sseBuilder.onTimeout(() -> { @@ -266,6 +265,7 @@ private ServerResponse handleGet(ServerRequest request) { // Check if this is a replay request if (request.headers().asHttpHeaders().containsKey(HttpHeaders.LAST_EVENT_ID)) { String lastId = request.headers().asHttpHeaders().getFirst(HttpHeaders.LAST_EVENT_ID); + try { session.replay(lastId) .contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)) @@ -288,13 +288,6 @@ private ServerResponse handleGet(ServerRequest request) { } } else { - McpLoggableSession listenedStream = session.getListeningStream(); - if (listenedStream instanceof McpStreamableServerSessionStream) { - logger.debug( - "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", - sessionId); - listenedStream.close(); - } // Establish new listening stream McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index bbc26a31c..8b95ec607 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -13,8 +13,6 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; -import io.modelcontextprotocol.spec.McpLoggableSession; -import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -275,6 +273,7 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } logger.debug("Handling GET request for session: {}", sessionId); + McpTransportContext transportContext = this.contextExtractor.extract(request, new DefaultMcpTransportContext()); try { @@ -316,13 +315,6 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } } else { - McpLoggableSession listenedStream = session.getListeningStream(); - if (listenedStream instanceof McpStreamableServerSessionStream) { - logger.debug( - "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", - sessionId); - listenedStream.close(); - } // Establish new listening stream McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session .listeningStream(sessionTransport); diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 1c3a1a702..4509f8337 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -4,28 +4,26 @@ package io.modelcontextprotocol.spec; -import java.time.Duration; -import java.util.Map; -import java.util.UUID; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicLong; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Supplier; - -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - import com.fasterxml.jackson.core.type.TypeReference; - import io.modelcontextprotocol.server.McpAsyncServerExchange; import io.modelcontextprotocol.server.McpNotificationHandler; import io.modelcontextprotocol.server.McpRequestHandler; import io.modelcontextprotocol.server.McpTransportContext; import io.modelcontextprotocol.util.Assert; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoSink; +import java.time.Duration; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + /** * Representation of a Streamable HTTP server session that keeps track of mapping * server-initiated requests to the client and mapping arriving responses. It also allows @@ -138,14 +136,16 @@ public Mono delete() { */ public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); - this.listeningStreamRef.set(listeningStream); + McpLoggableSession listenedStream = this.listeningStreamRef.getAndSet(listeningStream); + if (listenedStream != null) { + logger.debug( + "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", + this.id); + listenedStream.closeGracefully().block(); + } return listeningStream; } - public McpLoggableSession getListeningStream() { - return this.listeningStreamRef.get(); - } - // TODO: keep track of history by keeping a map from eventId to stream and then // iterate over the events using the lastEventId public Flux replay(Object lastEventId) { From 4b30e0a937e9354777c6e44447b0f94a08fccaeb Mon Sep 17 00:00:00 2001 From: taobaorun Date: Wed, 3 Sep 2025 11:35:51 +0800 Subject: [PATCH 4/5] fix listening sse stream close blocking --- ...FluxStreamableServerTransportProvider.java | 14 ++++- ...bMvcStreamableServerTransportProvider.java | 13 +++-- ...vletStreamableServerTransportProvider.java | 56 ++++++++++--------- .../spec/McpStreamableServerSession.java | 10 ++-- 4 files changed, 54 insertions(+), 39 deletions(-) diff --git a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java index b2b44f5cd..1e14d6fd5 100644 --- a/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webflux/src/main/java/io/modelcontextprotocol/server/transport/WebFluxStreamableServerTransportProvider.java @@ -12,6 +12,7 @@ import io.modelcontextprotocol.spec.McpError; import io.modelcontextprotocol.spec.McpSchema; import io.modelcontextprotocol.spec.McpStreamableServerSession; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import io.modelcontextprotocol.spec.McpStreamableServerTransport; import io.modelcontextprotocol.spec.McpStreamableServerTransportProvider; import io.modelcontextprotocol.spec.ProtocolVersions; @@ -199,9 +200,16 @@ private Mono handleGet(ServerRequest request) { .body(Flux.>create(sink -> { WebFluxStreamableMcpSessionTransport sessionTransport = new WebFluxStreamableMcpSessionTransport( sink); - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - sink.onDispose(listeningStream::close); + session.listeningStream(sessionTransport) + .doOnNext(serverSessionStream -> sink + .onDispose(() -> serverSessionStream.closeGracefully().subscribe(v -> { + }, error -> logger.warn("Failed to close listening stream gracefully", error)))) + .doOnError(error -> { + logger.error("Failed to create listening stream", error); + sink.error(error); + }) + .subscribe(serverSessionStream -> logger.debug("Listening stream created successfully"), + sink::error); }), ServerSentEvent.class); }).contextWrite(ctx -> ctx.put(McpTransportContext.KEY, transportContext)); diff --git a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java index fa51a0130..ca67277a5 100644 --- a/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java +++ b/mcp-spring/mcp-spring-webmvc/src/main/java/io/modelcontextprotocol/server/transport/WebMvcStreamableServerTransportProvider.java @@ -10,6 +10,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.http.HttpStatus; @@ -289,13 +290,15 @@ private ServerResponse handleGet(ServerRequest request) { } else { // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session + Mono listeningStream = session .listeningStream(sessionTransport); - - sseBuilder.onComplete(() -> { + listeningStream.subscribe(serverSessionStream -> sseBuilder.onComplete(() -> { logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - }); + serverSessionStream.close(); + }), error -> { + sseBuilder.error(error); + logger.error("Failed to create listening stream", error); + }, () -> logger.debug("Listening stream created successfully")); } }, Duration.ZERO); } diff --git a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java index 8b95ec607..9aaa64b54 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java +++ b/mcp/src/main/java/io/modelcontextprotocol/server/transport/HttpServletStreamableServerTransportProvider.java @@ -13,6 +13,7 @@ import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.locks.ReentrantLock; +import io.modelcontextprotocol.spec.McpStreamableServerSession.McpStreamableServerSessionStream; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -316,33 +317,36 @@ protected void doGet(HttpServletRequest request, HttpServletResponse response) } else { // Establish new listening stream - McpStreamableServerSession.McpStreamableServerSessionStream listeningStream = session - .listeningStream(sessionTransport); - - asyncContext.addListener(new jakarta.servlet.AsyncListener() { - @Override - public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection completed for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection timed out for session: {}", sessionId); - listeningStream.close(); - } + session.listeningStream(sessionTransport) + .doOnNext(serverSessionStream -> asyncContext.addListener(new jakarta.servlet.AsyncListener() { + @Override + public void onComplete(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection completed for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onTimeout(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection timed out for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onError(jakarta.servlet.AsyncEvent event) throws IOException { + logger.debug("SSE connection error for session: {}", sessionId); + serverSessionStream.close(); + } + + @Override + public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { + // No action needed + } + })) + .doOnError(error -> { + logger.error("Failed to create listening stream", error); + }) + .subscribe(serverSessionStream -> logger.debug("Listening stream created successfully")); - @Override - public void onError(jakarta.servlet.AsyncEvent event) throws IOException { - logger.debug("SSE connection error for session: {}", sessionId); - listeningStream.close(); - } - - @Override - public void onStartAsync(jakarta.servlet.AsyncEvent event) throws IOException { - // No action needed - } - }); } } catch (Exception e) { diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 4509f8337..dc931edac 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -134,16 +134,16 @@ public Mono delete() { * @param transport The dedicated SSE transport stream * @return a stream representation */ - public McpStreamableServerSessionStream listeningStream(McpStreamableServerTransport transport) { + public Mono listeningStream(McpStreamableServerTransport transport) { McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); - McpLoggableSession listenedStream = this.listeningStreamRef.getAndSet(listeningStream); - if (listenedStream != null) { + McpLoggableSession oldStream = this.listeningStreamRef.getAndSet(listeningStream); + if (oldStream != null) { logger.debug( "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", this.id); - listenedStream.closeGracefully().block(); + return oldStream.closeGracefully().thenReturn(listeningStream); } - return listeningStream; + return Mono.just(listeningStream); } // TODO: keep track of history by keeping a map from eventId to stream and then From 19d02f23c0dc7e9d63b95c082a761b3cfddb8033 Mon Sep 17 00:00:00 2001 From: taobaorun Date: Fri, 5 Sep 2025 15:15:01 +0800 Subject: [PATCH 5/5] optimize and add test for listening sse closed --- .../WebMvcStreamableIntegrationTests.java | 29 ++++++++++------ .../src/test/resources/logback.xml | 2 +- ...stractMcpClientServerIntegrationTests.java | 34 ++++++++++++++++++- .../spec/McpStreamableServerSession.java | 2 +- 4 files changed, 53 insertions(+), 14 deletions(-) diff --git a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java index 3f1716f89..e4dca31b2 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java +++ b/mcp-spring/mcp-spring-webmvc/src/test/java/io/modelcontextprotocol/server/WebMvcStreamableIntegrationTests.java @@ -81,17 +81,24 @@ public void before() { throw new RuntimeException("Failed to start Tomcat", e); } - clientBuilders - .put("httpclient", - McpClient.sync(HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) - .endpoint(MESSAGE_ENDPOINT) - .build()).initializationTimeout(Duration.ofHours(10)).requestTimeout(Duration.ofHours(10))); - - clientBuilders.put("webflux", - McpClient.sync(WebClientStreamableHttpTransport - .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) - .endpoint(MESSAGE_ENDPOINT) - .build())); + var httpClientTransport = HttpClientStreamableHttpTransport.builder("http://localhost:" + PORT) + .endpoint(MESSAGE_ENDPOINT) + .openConnectionOnStartup(true) + .build(); + + clientTransportBuilders.put("httpclient", httpClientTransport); + clientBuilders.put("httpclient", + McpClient.sync(httpClientTransport) + .initializationTimeout(Duration.ofHours(10)) + .requestTimeout(Duration.ofHours(10))); + var webClientTransport = WebClientStreamableHttpTransport + .builder(WebClient.builder().baseUrl("http://localhost:" + PORT)) + .endpoint(MESSAGE_ENDPOINT) + .openConnectionOnStartup(true) + .build(); + clientTransportBuilders.put("webflux", webClientTransport); + + clientBuilders.put("webflux", McpClient.sync(webClientTransport)); // Get the transport from Spring context this.mcpServerTransportProvider = tomcatServer.appContext() diff --git a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml index d4ccbc173..29d4a4850 100644 --- a/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml +++ b/mcp-spring/mcp-spring-webmvc/src/test/resources/logback.xml @@ -18,7 +18,7 @@ - + diff --git a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java index 300f0b534..e9c3a0a59 100644 --- a/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java +++ b/mcp-test/src/main/java/io/modelcontextprotocol/AbstractMcpClientServerIntegrationTests.java @@ -12,6 +12,9 @@ import static org.awaitility.Awaitility.await; import static org.mockito.Mockito.mock; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; import java.net.URI; import java.net.http.HttpClient; import java.net.http.HttpRequest; @@ -29,6 +32,7 @@ import java.util.function.Function; import java.util.stream.Collectors; +import io.modelcontextprotocol.spec.McpClientTransport; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -66,6 +70,8 @@ public abstract class AbstractMcpClientServerIntegrationTests { protected ConcurrentHashMap clientBuilders = new ConcurrentHashMap<>(); + protected ConcurrentHashMap clientTransportBuilders = new ConcurrentHashMap<>(); + abstract protected void prepareClients(int port, String mcpEndpoint); abstract protected McpServer.AsyncSpecification prepareAsyncServerBuilder(); @@ -836,7 +842,7 @@ void testThrowingToolCallIsCaughtBeforeTimeout(String clientType) { @ParameterizedTest(name = "{0} : {displayName} ") @ValueSource(strings = { "httpclient", "webflux" }) - void testToolCallSuccessWithTranportContextExtraction(String clientType) { + void testToolCallSuccessWithTransportContextExtraction(String clientType) { var clientBuilder = clientBuilders.get(clientType); @@ -999,6 +1005,32 @@ void testInitialize(String clientType) { mcpServer.close(); } + @ParameterizedTest(name = "{0} : {displayName} ") + @ValueSource(strings = { "httpclient", "webflux" }) + void testListeningStreamWillClosedWhenNew(String clientType) throws IOException { + var clientTransport = clientTransportBuilders.get(clientType); + if (clientTransport == null) { + return; + } + PrintStream originalOut = System.out; + ByteArrayOutputStream capturedOutput = new ByteArrayOutputStream(); + System.setOut(new PrintStream(capturedOutput)); + + var clientBuilder = clientBuilders.get(clientType); + var mcpServer = prepareSyncServerBuilder().build(); + var mcpClient = clientBuilder.build(); + InitializeResult initResult = mcpClient.initialize(); + assertThat(initResult).isNotNull(); + clientTransport.connect(message -> Mono.empty()).subscribe(); + await().atMost(Duration.ofSeconds(1)).untilAsserted(() -> { + assertThat(capturedOutput.toString().contains("Listening stream already exists for this session")).isTrue(); + }); + System.setOut(originalOut); + capturedOutput.close(); + mcpClient.close(); + mcpServer.close(); + } + // --------------------------------------- // Logging Tests // --------------------------------------- diff --git a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java index 30882d91c..3296a00d1 100644 --- a/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java +++ b/mcp/src/main/java/io/modelcontextprotocol/spec/McpStreamableServerSession.java @@ -139,7 +139,7 @@ public Mono delete() { public Mono listeningStream(McpStreamableServerTransport transport) { McpStreamableServerSessionStream listeningStream = new McpStreamableServerSessionStream(transport); McpLoggableSession oldStream = this.listeningStreamRef.getAndSet(listeningStream); - if (oldStream != null) { + if (oldStream != null && !(oldStream instanceof MissingMcpTransportSession)) { logger.debug( "Listening stream already exists for this session:{} and will be closed to make way for the new listening SSE stream", this.id);