1
1
pub mod mcp_server_runtime;
2
2
pub mod mcp_server_runtime_core;
3
+ use crate :: error:: SdkResult ;
4
+ use crate :: mcp_traits:: mcp_handler:: McpServerHandler ;
5
+ use crate :: mcp_traits:: mcp_server:: McpServer ;
3
6
use crate :: schema:: {
4
7
schema_utils:: {
5
- ClientMessage , ClientMessages , FromMessage , MessageFromServer , SdkError , ServerMessage ,
6
- ServerMessages ,
8
+ ClientMessage , ClientMessages , FromMessage , McpMessage , MessageFromServer , SdkError ,
9
+ ServerMessage , ServerMessages ,
7
10
} ,
8
11
InitializeRequestParams , InitializeResult , RequestId , RpcError ,
9
12
} ;
10
-
11
13
use async_trait:: async_trait;
12
14
use futures:: future:: try_join_all;
13
15
use futures:: { StreamExt , TryFutureExt } ;
14
-
16
+ #[ cfg( feature = "hyper-server" ) ]
17
+ use rust_mcp_transport:: SessionId ;
15
18
use rust_mcp_transport:: { IoStream , TransportDispatcher } ;
16
-
17
19
use std:: collections:: HashMap ;
18
20
use std:: sync:: { Arc , RwLock } ;
19
21
use std:: time:: Duration ;
20
22
use tokio:: io:: AsyncWriteExt ;
21
- use tokio:: sync:: oneshot;
23
+ use tokio:: sync:: { oneshot, watch } ;
22
24
23
- use crate :: error:: SdkResult ;
24
- use crate :: mcp_traits:: mcp_handler:: McpServerHandler ;
25
- use crate :: mcp_traits:: mcp_server:: McpServer ;
26
- #[ cfg( feature = "hyper-server" ) ]
27
- use rust_mcp_transport:: SessionId ;
28
25
pub const DEFAULT_STREAM_ID : & str = "STANDALONE-STREAM" ;
29
26
30
27
// Define a type alias for the TransportDispatcher trait object
@@ -49,21 +46,32 @@ pub struct ServerRuntime {
49
46
#[ cfg( feature = "hyper-server" ) ]
50
47
session_id : Option < SessionId > ,
51
48
transport_map : tokio:: sync:: RwLock < HashMap < String , TransportType > > ,
49
+ client_details_tx : watch:: Sender < Option < InitializeRequestParams > > ,
50
+ client_details_rx : watch:: Receiver < Option < InitializeRequestParams > > ,
52
51
}
53
52
54
53
#[ async_trait]
55
54
impl McpServer for ServerRuntime {
56
55
/// Set the client details, storing them in client_details
57
- fn set_client_details ( & self , client_details : InitializeRequestParams ) -> SdkResult < ( ) > {
58
- match self . client_details . write ( ) {
59
- Ok ( mut details) => {
60
- * details = Some ( client_details) ;
61
- Ok ( ( ) )
56
+ async fn set_client_details ( & self , client_details : InitializeRequestParams ) -> SdkResult < ( ) > {
57
+ self . handler . on_server_started ( self ) . await ;
58
+
59
+ self . client_details_tx
60
+ . send ( Some ( client_details) )
61
+ . map_err ( |_| {
62
+ RpcError :: internal_error ( )
63
+ . with_message ( "Failed to set client details" . to_string ( ) )
64
+ . into ( )
65
+ } )
66
+ }
67
+
68
+ async fn wait_for_initialization ( & self ) {
69
+ loop {
70
+ if self . client_details_rx . borrow ( ) . is_some ( ) {
71
+ return ;
62
72
}
63
- // Failed to acquire read lock, likely due to PoisonError from a thread panic. Returning None.
64
- Err ( _) => Err ( RpcError :: internal_error ( )
65
- . with_message ( "Internal Error: Failed to acquire write lock." . to_string ( ) )
66
- . into ( ) ) ,
73
+ let mut rx = self . client_details_rx . clone ( ) ;
74
+ rx. changed ( ) . await . ok ( ) ;
67
75
}
68
76
}
69
77
@@ -79,7 +87,19 @@ impl McpServer for ServerRuntime {
79
87
. with_message ( "transport stream does not exists or is closed!" . to_string ( ) ) ,
80
88
) ?;
81
89
82
- let mcp_message = ServerMessage :: from_message ( message, request_id) ?;
90
+ // generate a new request_id for request messages
91
+ let outgoing_request_id = if message. is_request ( ) {
92
+ match request_id {
93
+ Some ( _) => Err ( RpcError :: internal_error ( ) . with_message (
94
+ "request_id should not have a value when sending a new request" . to_string ( ) ,
95
+ ) ) ,
96
+ None => Ok ( self . next_request_id ( transport) . await ) ,
97
+ }
98
+ } else {
99
+ Ok ( request_id)
100
+ } ?;
101
+
102
+ let mcp_message = ServerMessage :: from_message ( message, outgoing_request_id) ?;
83
103
transport
84
104
. send_message ( ServerMessages :: Single ( mcp_message) , request_timeout)
85
105
. map_err ( |err| err. into ( ) )
@@ -130,8 +150,6 @@ impl McpServer for ServerRuntime {
130
150
131
151
let mut stream = transport. start ( ) . await ?;
132
152
133
- self . handler . on_server_started ( self ) . await ;
134
-
135
153
// Process incoming messages from the client
136
154
while let Some ( mcp_messages) = stream. next ( ) . await {
137
155
match mcp_messages {
@@ -207,6 +225,25 @@ impl ServerRuntime {
207
225
Ok ( ( ) )
208
226
}
209
227
228
+ pub ( crate ) async fn next_request_id (
229
+ & self ,
230
+ transport : & Arc <
231
+ dyn TransportDispatcher <
232
+ ClientMessages ,
233
+ MessageFromServer ,
234
+ ClientMessage ,
235
+ ServerMessages ,
236
+ ServerMessage ,
237
+ > ,
238
+ > ,
239
+ ) -> Option < RequestId > {
240
+ let message_sender = transport. message_sender ( ) ;
241
+ let guard = message_sender. read ( ) . await ;
242
+ guard
243
+ . as_ref ( )
244
+ . map ( |dispatcher| dispatcher. next_request_id ( ) )
245
+ }
246
+
210
247
pub ( crate ) async fn handle_message (
211
248
& self ,
212
249
message : ClientMessage ,
@@ -416,12 +453,16 @@ impl ServerRuntime {
416
453
handler : Arc < dyn McpServerHandler > ,
417
454
session_id : SessionId ,
418
455
) -> Self {
456
+ let ( client_details_tx, client_details_rx) =
457
+ watch:: channel :: < Option < InitializeRequestParams > > ( None ) ;
419
458
Self {
420
459
server_details,
421
460
client_details : Arc :: new ( RwLock :: new ( None ) ) ,
422
461
handler,
423
462
session_id : Some ( session_id) ,
424
463
transport_map : tokio:: sync:: RwLock :: new ( HashMap :: new ( ) ) ,
464
+ client_details_tx,
465
+ client_details_rx,
425
466
}
426
467
}
427
468
@@ -438,13 +479,17 @@ impl ServerRuntime {
438
479
) -> Self {
439
480
let mut map: HashMap < String , TransportType > = HashMap :: new ( ) ;
440
481
map. insert ( DEFAULT_STREAM_ID . to_string ( ) , Arc :: new ( transport) ) ;
482
+ let ( client_details_tx, client_details_rx) =
483
+ watch:: channel :: < Option < InitializeRequestParams > > ( None ) ;
441
484
Self {
442
485
server_details : Arc :: new ( server_details) ,
443
486
client_details : Arc :: new ( RwLock :: new ( None ) ) ,
444
487
handler,
445
488
#[ cfg( feature = "hyper-server" ) ]
446
489
session_id : None ,
447
490
transport_map : tokio:: sync:: RwLock :: new ( map) ,
491
+ client_details_tx,
492
+ client_details_rx,
448
493
}
449
494
}
450
495
}
0 commit comments