@@ -8,7 +8,6 @@ import io.ktor.server.response.respond
8
8
import io.ktor.server.routing.Routing
9
9
import io.ktor.server.routing.RoutingContext
10
10
import io.ktor.server.routing.post
11
- import io.ktor.server.routing.route
12
11
import io.ktor.server.routing.routing
13
12
import io.ktor.server.sse.SSE
14
13
import io.ktor.server.sse.ServerSSESession
@@ -18,39 +17,40 @@ import kotlinx.atomicfu.AtomicRef
18
17
import kotlinx.atomicfu.atomic
19
18
import kotlinx.atomicfu.update
20
19
import kotlinx.collections.immutable.PersistentMap
21
- import kotlinx.collections.immutable.persistentMapOf
20
+ import kotlinx.collections.immutable.toPersistentMap
22
21
23
22
private val logger = KotlinLogging .logger {}
24
23
25
- @KtorDsl
26
- public fun Routing.mcp (path : String , block : () -> Server ) {
27
- route(path) {
28
- mcp(block)
24
+ internal class SseTransportManager (transports : Map <String , SseServerTransport > = emptyMap()) {
25
+ private val transports: AtomicRef <PersistentMap <String , SseServerTransport >> = atomic(transports.toPersistentMap())
26
+
27
+ fun getTransport (sessionId : String ): SseServerTransport ? = transports.value[sessionId]
28
+
29
+ fun addTransport (transport : SseServerTransport ) {
30
+ transports.update { it.put(transport.sessionId, transport) }
31
+ }
32
+
33
+ fun removeTransport (sessionId : String ) {
34
+ transports.update { it.remove(sessionId) }
29
35
}
30
36
}
31
37
32
- /* *
33
- * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
34
- */
38
+ /*
39
+ * Configures the Ktor Application to handle Model Context Protocol (MCP) over Server-Sent Events (SSE).
40
+ */
35
41
@KtorDsl
36
42
public fun Routing.mcp (block : () -> Server ) {
37
- val transports = atomic(persistentMapOf< String , SseServerTransport >() )
43
+ val sseTransportManager = SseTransportManager ( )
38
44
39
45
sse {
40
- mcpSseEndpoint(" " , transports , block)
46
+ mcpSseEndpoint(" " , sseTransportManager , block)
41
47
}
42
48
43
49
post {
44
- mcpPostEndpoint(transports )
50
+ mcpPostEndpoint(sseTransportManager )
45
51
}
46
52
}
47
53
48
- @Suppress(" FunctionName" )
49
- @Deprecated(" Use mcp() instead" , ReplaceWith (" mcp(block)" ), DeprecationLevel .WARNING )
50
- public fun Application.MCP (block : () -> Server ) {
51
- mcp(block)
52
- }
53
-
54
54
@KtorDsl
55
55
public fun Application.mcp (block : () -> Server ) {
56
56
install(SSE )
@@ -62,16 +62,16 @@ public fun Application.mcp(block: () -> Server) {
62
62
63
63
internal suspend fun ServerSSESession.mcpSseEndpoint (
64
64
postEndpoint : String ,
65
- transports : AtomicRef < PersistentMap < String , SseServerTransport >> ,
65
+ sseTransportManager : SseTransportManager ,
66
66
block : () -> Server ,
67
67
) {
68
- val transport = mcpSseTransport(postEndpoint, transports )
68
+ val transport = mcpSseTransport(postEndpoint, sseTransportManager )
69
69
70
70
val server = block()
71
71
72
72
server.onClose {
73
73
logger.info { " Server connection closed for sessionId: ${transport.sessionId} " }
74
- transports.update { it.remove (transport.sessionId) }
74
+ sseTransportManager.removeTransport (transport.sessionId)
75
75
}
76
76
77
77
server.connectSession(transport)
@@ -81,17 +81,17 @@ internal suspend fun ServerSSESession.mcpSseEndpoint(
81
81
82
82
internal fun ServerSSESession.mcpSseTransport (
83
83
postEndpoint : String ,
84
- transports : AtomicRef < PersistentMap < String , SseServerTransport >> ,
84
+ sseTransportManager : SseTransportManager ,
85
85
): SseServerTransport {
86
86
val transport = SseServerTransport (postEndpoint, this )
87
- transports.update { it.put (transport.sessionId, transport) }
87
+ sseTransportManager.addTransport (transport)
88
88
logger.info { " New SSE connection established and stored with sessionId: ${transport.sessionId} " }
89
89
90
90
return transport
91
91
}
92
92
93
93
internal suspend fun RoutingContext.mcpPostEndpoint (
94
- transports : AtomicRef < PersistentMap < String , SseServerTransport >> ,
94
+ sseTransportManager : SseTransportManager ,
95
95
) {
96
96
val sessionId: String = call.request.queryParameters[" sessionId" ]
97
97
? : run {
@@ -101,7 +101,7 @@ internal suspend fun RoutingContext.mcpPostEndpoint(
101
101
102
102
logger.debug { " Received message for sessionId: $sessionId " }
103
103
104
- val transport = transports.value[ sessionId]
104
+ val transport = sseTransportManager.getTransport( sessionId)
105
105
if (transport == null ) {
106
106
logger.warn { " Session not found for sessionId: $sessionId " }
107
107
call.respond(HttpStatusCode .NotFound , " Session not found" )
0 commit comments