Skip to content

Commit 308b1db

Browse files
authored
fix: handle missing client details and abort keep-alive task on drop (#83)
- Added guard (AbortTaskOnDrop) to ensure keep-alive task is aborted when no longer needed - Fixed bug where client_info was mistakenly returning None
1 parent aacdbfe commit 308b1db

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,15 @@ use crate::schema::{
1010
},
1111
InitializeRequestParams, InitializeResult, RequestId, RpcError,
1212
};
13+
use crate::utils::AbortTaskOnDrop;
1314
use async_trait::async_trait;
1415
use futures::future::try_join_all;
1516
use futures::{StreamExt, TryFutureExt};
1617
#[cfg(feature = "hyper-server")]
1718
use rust_mcp_transport::SessionId;
1819
use rust_mcp_transport::{IoStream, TransportDispatcher};
1920
use std::collections::HashMap;
20-
use std::sync::{Arc, RwLock};
21+
use std::sync::Arc;
2122
use std::time::Duration;
2223
use tokio::io::AsyncWriteExt;
2324
use tokio::sync::{oneshot, watch};
@@ -41,8 +42,6 @@ pub struct ServerRuntime {
4142
handler: Arc<dyn McpServerHandler>,
4243
// Information about the server
4344
server_details: Arc<InitializeResult>,
44-
// Details about the connected client
45-
client_details: Arc<RwLock<Option<InitializeRequestParams>>>,
4645
#[cfg(feature = "hyper-server")]
4746
session_id: Option<SessionId>,
4847
transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
@@ -123,12 +122,7 @@ impl McpServer for ServerRuntime {
123122

124123
/// Returns the client information if available, after successful initialization , otherwise returns None
125124
fn client_info(&self) -> Option<InitializeRequestParams> {
126-
if let Ok(details) = self.client_details.read() {
127-
details.clone()
128-
} else {
129-
// Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
130-
None
131-
}
125+
self.client_details_rx.borrow().clone()
132126
}
133127

134128
/// Main runtime loop, processes incoming messages and handles requests
@@ -404,6 +398,11 @@ impl ServerRuntime {
404398
.await?
405399
.abort_handle();
406400

401+
// ensure keep_alive task will be aborted
402+
let _abort_guard = AbortTaskOnDrop {
403+
handle: abort_alive_task,
404+
};
405+
407406
// in case there is a payload, we consume it by transport to get processed
408407
if let Some(payload) = payload {
409408
transport.consume_string_payload(&payload).await?;
@@ -439,13 +438,11 @@ impl ServerRuntime {
439438
}
440439
// close the stream after all messages are sent, unless it is a standalone stream
441440
if !stream_id.eq(DEFAULT_STREAM_ID){
442-
abort_alive_task.abort();
443441
return Ok(());
444442
}
445443
}
446444
_ = &mut disconnect_rx => {
447445
self.remove_transport(stream_id).await?;
448-
abort_alive_task.abort();
449446
// Disconnection detected by keep-alive task
450447
return Err(SdkError::connection_closed().into());
451448

@@ -469,7 +466,6 @@ impl ServerRuntime {
469466
watch::channel::<Option<InitializeRequestParams>>(None);
470467
Self {
471468
server_details,
472-
client_details: Arc::new(RwLock::new(None)),
473469
handler,
474470
session_id: Some(session_id),
475471
transport_map: tokio::sync::RwLock::new(HashMap::new()),
@@ -495,7 +491,6 @@ impl ServerRuntime {
495491
watch::channel::<Option<InitializeRequestParams>>(None);
496492
Self {
497493
server_details: Arc::new(server_details),
498-
client_details: Arc::new(RwLock::new(None)),
499494
handler,
500495
#[cfg(feature = "hyper-server")]
501496
session_id: None,

crates/rust-mcp-sdk/src/utils.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,23 @@ use crate::error::{McpSdkError, SdkResult};
44
use crate::schema::ProtocolVersion;
55
use std::cmp::Ordering;
66

7+
/// A guard type that automatically aborts a Tokio task when dropped.
8+
///
9+
/// This ensures that the associated task does not outlive the scope
10+
/// of this struct, preventing runaway or leaked background tasks.
11+
///
12+
pub struct AbortTaskOnDrop {
13+
/// The handle used to abort the spawned Tokio task.
14+
pub handle: tokio::task::AbortHandle,
15+
}
16+
17+
impl Drop for AbortTaskOnDrop {
18+
fn drop(&mut self) {
19+
// Automatically abort the associated task when this guard is dropped.
20+
self.handle.abort();
21+
}
22+
}
23+
724
/// Formats an assertion error message for unsupported capabilities.
825
///
926
/// Constructs a string describing that a specific entity (e.g., server or client) lacks

0 commit comments

Comments
 (0)