Skip to content

Commit 5f9a966

Browse files
authored
fix: ensure server-initiated requests include a valid request_id (#80)
1 parent 1457cff commit 5f9a966

File tree

4 files changed

+73
-24
lines changed

4 files changed

+73
-24
lines changed

crates/rust-mcp-sdk/src/mcp_handlers/mcp_server_handler.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ pub trait ServerHandler: Send + Sync + 'static {
5151

5252
runtime
5353
.set_client_details(initialize_request.params.clone())
54+
.await
5455
.map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?;
5556

5657
Ok(server_info)

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime.rs

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,27 @@
11
pub mod mcp_server_runtime;
22
pub mod mcp_server_runtime_core;
3+
use crate::error::SdkResult;
4+
use crate::mcp_traits::mcp_handler::McpServerHandler;
5+
use crate::mcp_traits::mcp_server::McpServer;
36
use crate::schema::{
47
schema_utils::{
5-
ClientMessage, ClientMessages, FromMessage, MessageFromServer, SdkError, ServerMessage,
6-
ServerMessages,
8+
ClientMessage, ClientMessages, FromMessage, McpMessage, MessageFromServer, SdkError,
9+
ServerMessage, ServerMessages,
710
},
811
InitializeRequestParams, InitializeResult, RequestId, RpcError,
912
};
10-
1113
use async_trait::async_trait;
1214
use futures::future::try_join_all;
1315
use futures::{StreamExt, TryFutureExt};
14-
16+
#[cfg(feature = "hyper-server")]
17+
use rust_mcp_transport::SessionId;
1518
use rust_mcp_transport::{IoStream, TransportDispatcher};
16-
1719
use std::collections::HashMap;
1820
use std::sync::{Arc, RwLock};
1921
use std::time::Duration;
2022
use tokio::io::AsyncWriteExt;
21-
use tokio::sync::oneshot;
23+
use tokio::sync::{oneshot, watch};
2224

23-
use crate::error::SdkResult;
24-
use crate::mcp_traits::mcp_handler::McpServerHandler;
25-
use crate::mcp_traits::mcp_server::McpServer;
26-
#[cfg(feature = "hyper-server")]
27-
use rust_mcp_transport::SessionId;
2825
pub const DEFAULT_STREAM_ID: &str = "STANDALONE-STREAM";
2926

3027
// Define a type alias for the TransportDispatcher trait object
@@ -49,21 +46,32 @@ pub struct ServerRuntime {
4946
#[cfg(feature = "hyper-server")]
5047
session_id: Option<SessionId>,
5148
transport_map: tokio::sync::RwLock<HashMap<String, TransportType>>,
49+
client_details_tx: watch::Sender<Option<InitializeRequestParams>>,
50+
client_details_rx: watch::Receiver<Option<InitializeRequestParams>>,
5251
}
5352

5453
#[async_trait]
5554
impl McpServer for ServerRuntime {
5655
/// Set the client details, storing them in client_details
57-
fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
58-
match self.client_details.write() {
59-
Ok(mut details) => {
60-
*details = Some(client_details);
61-
Ok(())
56+
async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()> {
57+
self.handler.on_server_started(self).await;
58+
59+
self.client_details_tx
60+
.send(Some(client_details))
61+
.map_err(|_| {
62+
RpcError::internal_error()
63+
.with_message("Failed to set client details".to_string())
64+
.into()
65+
})
66+
}
67+
68+
async fn wait_for_initialization(&self) {
69+
loop {
70+
if self.client_details_rx.borrow().is_some() {
71+
return;
6272
}
63-
// Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
64-
Err(_) => Err(RpcError::internal_error()
65-
.with_message("Internal Error: Failed to acquire write lock.".to_string())
66-
.into()),
73+
let mut rx = self.client_details_rx.clone();
74+
rx.changed().await.ok();
6775
}
6876
}
6977

@@ -79,7 +87,19 @@ impl McpServer for ServerRuntime {
7987
.with_message("transport stream does not exists or is closed!".to_string()),
8088
)?;
8189

82-
let mcp_message = ServerMessage::from_message(message, request_id)?;
90+
// generate a new request_id for request messages
91+
let outgoing_request_id = if message.is_request() {
92+
match request_id {
93+
Some(_) => Err(RpcError::internal_error().with_message(
94+
"request_id should not have a value when sending a new request".to_string(),
95+
)),
96+
None => Ok(self.next_request_id(transport).await),
97+
}
98+
} else {
99+
Ok(request_id)
100+
}?;
101+
102+
let mcp_message = ServerMessage::from_message(message, outgoing_request_id)?;
83103
transport
84104
.send_message(ServerMessages::Single(mcp_message), request_timeout)
85105
.map_err(|err| err.into())
@@ -130,8 +150,6 @@ impl McpServer for ServerRuntime {
130150

131151
let mut stream = transport.start().await?;
132152

133-
self.handler.on_server_started(self).await;
134-
135153
// Process incoming messages from the client
136154
while let Some(mcp_messages) = stream.next().await {
137155
match mcp_messages {
@@ -207,6 +225,25 @@ impl ServerRuntime {
207225
Ok(())
208226
}
209227

228+
pub(crate) async fn next_request_id(
229+
&self,
230+
transport: &Arc<
231+
dyn TransportDispatcher<
232+
ClientMessages,
233+
MessageFromServer,
234+
ClientMessage,
235+
ServerMessages,
236+
ServerMessage,
237+
>,
238+
>,
239+
) -> Option<RequestId> {
240+
let message_sender = transport.message_sender();
241+
let guard = message_sender.read().await;
242+
guard
243+
.as_ref()
244+
.map(|dispatcher| dispatcher.next_request_id())
245+
}
246+
210247
pub(crate) async fn handle_message(
211248
&self,
212249
message: ClientMessage,
@@ -416,12 +453,16 @@ impl ServerRuntime {
416453
handler: Arc<dyn McpServerHandler>,
417454
session_id: SessionId,
418455
) -> Self {
456+
let (client_details_tx, client_details_rx) =
457+
watch::channel::<Option<InitializeRequestParams>>(None);
419458
Self {
420459
server_details,
421460
client_details: Arc::new(RwLock::new(None)),
422461
handler,
423462
session_id: Some(session_id),
424463
transport_map: tokio::sync::RwLock::new(HashMap::new()),
464+
client_details_tx,
465+
client_details_rx,
425466
}
426467
}
427468

@@ -438,13 +479,17 @@ impl ServerRuntime {
438479
) -> Self {
439480
let mut map: HashMap<String, TransportType> = HashMap::new();
440481
map.insert(DEFAULT_STREAM_ID.to_string(), Arc::new(transport));
482+
let (client_details_tx, client_details_rx) =
483+
watch::channel::<Option<InitializeRequestParams>>(None);
441484
Self {
442485
server_details: Arc::new(server_details),
443486
client_details: Arc::new(RwLock::new(None)),
444487
handler,
445488
#[cfg(feature = "hyper-server")]
446489
session_id: None,
447490
transport_map: tokio::sync::RwLock::new(map),
491+
client_details_tx,
492+
client_details_rx,
448493
}
449494
}
450495
}

crates/rust-mcp-sdk/src/mcp_runtimes/server_runtime/mcp_server_runtime_core.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ impl McpServerHandler for RuntimeCoreInternalHandler<Box<dyn ServerHandlerCore>>
7676
// keep a copy of the InitializeRequestParams which includes client_info and capabilities
7777
runtime
7878
.set_client_details(initialize_request.params.clone())
79+
.await
7980
.map_err(|err| RpcError::internal_error().with_message(format!("{err}")))?;
8081
}
8182

crates/rust-mcp-sdk/src/mcp_traits/mcp_server.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@ use crate::{error::SdkResult, utils::format_assertion_message};
2323
#[async_trait]
2424
pub trait McpServer: Sync + Send {
2525
async fn start(&self) -> SdkResult<()>;
26-
fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
26+
async fn set_client_details(&self, client_details: InitializeRequestParams) -> SdkResult<()>;
2727
fn server_info(&self) -> &InitializeResult;
2828
fn client_info(&self) -> Option<InitializeRequestParams>;
2929

30+
async fn wait_for_initialization(&self);
31+
3032
#[deprecated(since = "0.2.0", note = "Use `client_info()` instead.")]
3133
fn get_client_info(&self) -> Option<InitializeRequestParams> {
3234
self.client_info()

0 commit comments

Comments
 (0)