@@ -10,14 +10,15 @@ use crate::schema::{
10
10
} ,
11
11
InitializeRequestParams , InitializeResult , RequestId , RpcError ,
12
12
} ;
13
+ use crate :: utils:: AbortTaskOnDrop ;
13
14
use async_trait:: async_trait;
14
15
use futures:: future:: try_join_all;
15
16
use futures:: { StreamExt , TryFutureExt } ;
16
17
#[ cfg( feature = "hyper-server" ) ]
17
18
use rust_mcp_transport:: SessionId ;
18
19
use rust_mcp_transport:: { IoStream , TransportDispatcher } ;
19
20
use std:: collections:: HashMap ;
20
- use std:: sync:: { Arc , RwLock } ;
21
+ use std:: sync:: Arc ;
21
22
use std:: time:: Duration ;
22
23
use tokio:: io:: AsyncWriteExt ;
23
24
use tokio:: sync:: { oneshot, watch} ;
@@ -41,8 +42,6 @@ pub struct ServerRuntime {
41
42
handler : Arc < dyn McpServerHandler > ,
42
43
// Information about the server
43
44
server_details : Arc < InitializeResult > ,
44
- // Details about the connected client
45
- client_details : Arc < RwLock < Option < InitializeRequestParams > > > ,
46
45
#[ cfg( feature = "hyper-server" ) ]
47
46
session_id : Option < SessionId > ,
48
47
transport_map : tokio:: sync:: RwLock < HashMap < String , TransportType > > ,
@@ -123,12 +122,7 @@ impl McpServer for ServerRuntime {
123
122
124
123
/// Returns the client information if available, after successful initialization , otherwise returns None
125
124
fn client_info ( & self ) -> Option < InitializeRequestParams > {
126
- if let Ok ( details) = self . client_details . read ( ) {
127
- details. clone ( )
128
- } else {
129
- // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
130
- None
131
- }
125
+ self . client_details_rx . borrow ( ) . clone ( )
132
126
}
133
127
134
128
/// Main runtime loop, processes incoming messages and handles requests
@@ -404,6 +398,11 @@ impl ServerRuntime {
404
398
. await ?
405
399
. abort_handle ( ) ;
406
400
401
+ // ensure keep_alive task will be aborted
402
+ let _abort_guard = AbortTaskOnDrop {
403
+ handle : abort_alive_task,
404
+ } ;
405
+
407
406
// in case there is a payload, we consume it by transport to get processed
408
407
if let Some ( payload) = payload {
409
408
transport. consume_string_payload ( & payload) . await ?;
@@ -439,13 +438,11 @@ impl ServerRuntime {
439
438
}
440
439
// close the stream after all messages are sent, unless it is a standalone stream
441
440
if !stream_id. eq( DEFAULT_STREAM_ID ) {
442
- abort_alive_task. abort( ) ;
443
441
return Ok ( ( ) ) ;
444
442
}
445
443
}
446
444
_ = & mut disconnect_rx => {
447
445
self . remove_transport( stream_id) . await ?;
448
- abort_alive_task. abort( ) ;
449
446
// Disconnection detected by keep-alive task
450
447
return Err ( SdkError :: connection_closed( ) . into( ) ) ;
451
448
@@ -469,7 +466,6 @@ impl ServerRuntime {
469
466
watch:: channel :: < Option < InitializeRequestParams > > ( None ) ;
470
467
Self {
471
468
server_details,
472
- client_details : Arc :: new ( RwLock :: new ( None ) ) ,
473
469
handler,
474
470
session_id : Some ( session_id) ,
475
471
transport_map : tokio:: sync:: RwLock :: new ( HashMap :: new ( ) ) ,
@@ -495,7 +491,6 @@ impl ServerRuntime {
495
491
watch:: channel :: < Option < InitializeRequestParams > > ( None ) ;
496
492
Self {
497
493
server_details : Arc :: new ( server_details) ,
498
- client_details : Arc :: new ( RwLock :: new ( None ) ) ,
499
494
handler,
500
495
#[ cfg( feature = "hyper-server" ) ]
501
496
session_id : None ,
0 commit comments