Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 69 additions & 16 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@ use crossterm::{
terminal::{disable_raw_mode, enable_raw_mode, EnterAlternateScreen, LeaveAlternateScreen},
};
use logging::{LogBuffer, LogBufferLayer};
use mcp::McpClient;
use mcp::{McpClient, ResponseMessage};
use ratatui::{backend::CrosstermBackend, Terminal};
use std::io;
use std::sync::Arc;
use tracing::Level;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
Expand Down Expand Up @@ -78,8 +79,9 @@ async fn run_tui(
.await
.context("Failed to initialize MCP client")?;

let client = Arc::new(client);
let mut app = App::new(debug_mode);
let res = run_tui_loop(&mut terminal, &mut app, &client, log_buffer).await;
let res = run_tui_loop(&mut terminal, &mut app, client.clone(), log_buffer).await;

disable_raw_mode()?;
execute!(
Expand All @@ -97,29 +99,80 @@ async fn run_tui(
async fn run_tui_loop(
terminal: &mut Terminal<CrosstermBackend<io::Stdout>>,
app: &mut App,
client: &McpClient,
client: Arc<McpClient>,
log_buffer: LogBuffer,
) -> Result<()> {
app.load_data(client).await?;
app.load_data(&client).await?;

loop {
// Update logs in the background
app.update_logs(client).await;
app.update_logs(&client).await;

// Update debug logs from buffer
app.update_debug_logs(log_buffer.get_all());

// Poll for tool call result (tool call runs in a spawned task so we can receive elicitation mid-call)
if let Some(rx) = app.tool_call_pending_rx.as_mut() {
match rx.try_recv() {
Ok((tool_name, result)) => app.apply_pending_tool_result(tool_name, result),
Err(tokio::sync::mpsc::error::TryRecvError::Disconnected) => {
app.tool_call_pending_rx = None;
}
Err(tokio::sync::mpsc::error::TryRecvError::Empty) => {}
}
}

// Poll for server-originated requests (e.g. elicitation/create)
if let Some(ResponseMessage::Notification(req)) = client.try_recv_server_message().await {
if req.method == "elicitation/create" {
let id = req.id.clone().unwrap_or(serde_json::Value::Null);
if let Err(e) = app.start_elicitation(id, req.params, &client).await {
app.error_message = Some(format!("Elicitation error: {}", e));
}
}
}

terminal.draw(|f| render_ui(f, app))?;

if event::poll(std::time::Duration::from_millis(100))? {
if let Event::Key(key) = event::read()? {
if key.kind == KeyEventKind::Press {
if app.tool_call_input_mode {
// Handle tool call input mode
// Elicitation first: server may send it during tool/prompt form; user must respond
if app.elicitation_input_mode {
match key.code {
KeyCode::Esc => app.cancel_elicitation(&client).await,
KeyCode::Enter => app.execute_elicitation_accept(&client).await,
KeyCode::Tab => {
if key.modifiers.contains(KeyModifiers::SHIFT) {
app.previous_input_field();
} else {
app.next_input_field();
}
}
KeyCode::BackTab => app.previous_input_field(),
KeyCode::Backspace => app.delete_current_input(),
KeyCode::Up => app.scroll_tool_input_up(),
KeyCode::Down => app.scroll_tool_input_down(),
KeyCode::Char(c) => {
// Ctrl+D = decline (so plain 'd'/'D' can be typed in fields)
if (c == 'd' || c == 'D')
&& key.modifiers.contains(KeyModifiers::CONTROL)
{
app.decline_elicitation(&client).await
} else {
app.update_current_input(c)
}
}
_ => {}
}
} else if app.tool_call_input_mode {
// Handle tool call input mode (or waiting for response)
match key.code {
KeyCode::Esc => app.cancel_tool_call(),
KeyCode::Enter => {
app.execute_tool_call(client).await;
if app.tool_call_pending_rx.is_none() {
app.execute_tool_call(client.clone());
}
}
KeyCode::Tab => {
if key.modifiers.contains(KeyModifiers::SHIFT) {
Expand All @@ -140,7 +193,7 @@ async fn run_tui_loop(
match key.code {
KeyCode::Esc => app.cancel_prompt_input(),
KeyCode::Enter => {
app.execute_prompt_get(client).await;
app.execute_prompt_get(&client).await;
}
KeyCode::Tab => {
if key.modifiers.contains(KeyModifiers::SHIFT) {
Expand All @@ -163,7 +216,7 @@ async fn run_tui_loop(
KeyCode::Char('c') | KeyCode::Char('C') => match app.current_tab {
tui::Tab::Tools => app.start_tool_call(),
tui::Tab::Prompts => app.start_prompt_get(),
tui::Tab::Resources => app.read_resource(client).await,
tui::Tab::Resources => app.read_resource(&client).await,
_ => {}
},
KeyCode::Down => app.next_item(),
Expand All @@ -178,32 +231,32 @@ async fn run_tui_loop(
KeyCode::Char('c') | KeyCode::Char('C') => match app.current_tab {
tui::Tab::Tools => app.start_tool_call(),
tui::Tab::Prompts => app.start_prompt_get(),
tui::Tab::Resources => app.read_resource(client).await,
tui::Tab::Resources => app.read_resource(&client).await,
_ => {}
},
KeyCode::Tab => {
app.current_tab = app.current_tab.next(app.debug_mode);
app.load_data(client).await?;
app.load_data(&client).await?;
}
KeyCode::BackTab => {
app.current_tab = app.current_tab.previous(app.debug_mode);
app.load_data(client).await?;
app.load_data(&client).await?;
}
KeyCode::Left => {
app.current_tab = app.current_tab.previous(app.debug_mode);
app.load_data(client).await?;
app.load_data(&client).await?;
}
KeyCode::Right => {
app.current_tab = app.current_tab.next(app.debug_mode);
app.load_data(client).await?;
app.load_data(&client).await?;
}
KeyCode::Down => app.next_item(),
KeyCode::Up => app.previous_item(),
KeyCode::PageDown => app.page_down(),
KeyCode::PageUp => app.page_up(),
KeyCode::Enter => app.show_detail(),
KeyCode::Char('r') | KeyCode::Char('R') => {
app.load_data(client).await?;
app.load_data(&client).await?;
}
KeyCode::Char('e') | KeyCode::Char('E') => {
app.scroll_to_bottom();
Expand Down
39 changes: 30 additions & 9 deletions src/mcp/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,17 @@ pub struct McpClient {
child: Arc<Mutex<Child>>,
stdin: Arc<Mutex<ChildStdin>>,
request_id: AtomicI64,
#[allow(dead_code)]
#[allow(dead_code)] // held to keep channel alive for read_loop
response_tx: mpsc::UnboundedSender<ResponseMessage>,
#[allow(dead_code)]
response_rx: Arc<Mutex<mpsc::UnboundedReceiver<ResponseMessage>>>,
pending_requests: Arc<Mutex<HashMap<i64, oneshot::Sender<JsonRpcResponse>>>>,
server_info: Arc<Mutex<Option<InitializeResult>>>,
log_rx: Arc<Mutex<mpsc::UnboundedReceiver<String>>>,
}

enum ResponseMessage {
pub enum ResponseMessage {
#[allow(dead_code)]
Response(JsonRpcResponse),
#[allow(dead_code)]
Notification(JsonRpcRequest),
}

Expand Down Expand Up @@ -97,7 +95,12 @@ impl McpClient {

debug!("Received: {}", trimmed);

if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
// Try JsonRpcRequest first: server-originated requests (e.g. elicitation/create)
// have "method" and would incorrectly parse as JsonRpcResponse (extra fields
// ignored), then be mistaken for a response to our pending request.
if let Ok(request) = serde_json::from_str::<JsonRpcRequest>(trimmed) {
let _ = response_tx.send(ResponseMessage::Notification(request));
} else if let Ok(response) = serde_json::from_str::<JsonRpcResponse>(trimmed) {
if let Value::Number(id) = &response.id {
if let Some(id) = id.as_i64() {
let mut pending = pending_requests.lock().await;
Expand All @@ -108,9 +111,6 @@ impl McpClient {
}
}
let _ = response_tx.send(ResponseMessage::Response(response));
} else if let Ok(notification) = serde_json::from_str::<JsonRpcRequest>(trimmed)
{
let _ = response_tx.send(ResponseMessage::Notification(notification));
} else {
warn!("Failed to parse message: {}", trimmed);
}
Expand Down Expand Up @@ -204,10 +204,11 @@ impl McpClient {

pub async fn initialize(&self) -> Result<InitializeResult> {
let params = InitializeParams {
protocol_version: "2024-11-05".to_string(),
protocol_version: "2025-11-25".to_string(),
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: Some(ElicitationCapability::default()),
},
client_info: Implementation {
name: env!("CARGO_PKG_NAME").to_string(),
Expand Down Expand Up @@ -293,6 +294,25 @@ impl McpClient {
logs
}

/// Non-blocking receive of server-originated requests or notifications (e.g. elicitation/create).
pub async fn try_recv_server_message(&self) -> Option<ResponseMessage> {
let mut rx = self.response_rx.lock().await;
rx.try_recv().ok()
}

/// Send a JSON-RPC response to the server (e.g. in reply to elicitation/create).
pub async fn send_response(&self, response: JsonRpcResponse) -> Result<()> {
let json = serde_json::to_string(&response)?;
debug!("Sending response: {}", json);

let mut stdin = self.stdin.lock().await;
stdin.write_all(json.as_bytes()).await?;
stdin.write_all(b"\n").await?;
stdin.flush().await?;

Ok(())
}

pub async fn shutdown(&self) -> Result<()> {
let _ = self.child.lock().await.kill().await;
Ok(())
Expand Down Expand Up @@ -387,6 +407,7 @@ mod tests {
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
},
client_info: Implementation {
name: env!("CARGO_PKG_NAME").to_string(),
Expand Down
2 changes: 1 addition & 1 deletion src/mcp/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
pub mod client;
pub mod protocol;

pub use client::McpClient;
pub use client::{McpClient, ResponseMessage};
24 changes: 24 additions & 0 deletions src/mcp/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,33 @@ pub struct InitializeParams {
pub client_info: Implementation,
}

#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct ElicitationCapability {}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientCapabilities {
#[serde(skip_serializing_if = "Option::is_none")]
pub roots: Option<RootsCapability>,
#[serde(skip_serializing_if = "Option::is_none")]
pub sampling: Option<HashMap<String, Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation: Option<ElicitationCapability>,
}

/// Params for server-originated `elicitation/create` request (form mode).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElicitationCreateParams {
#[serde(default)]
pub mode: Option<String>,
pub message: String,
#[serde(rename = "requestedSchema")]
#[serde(skip_serializing_if = "Option::is_none")]
pub requested_schema: Option<Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub url: Option<String>,
#[serde(rename = "elicitationId")]
#[serde(skip_serializing_if = "Option::is_none")]
pub elicitation_id: Option<String>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -358,6 +379,7 @@ mod tests {
capabilities: ClientCapabilities {
roots: Some(RootsCapability { list_changed: true }),
sampling: None,
elicitation: None,
},
client_info: Implementation {
name: "test_client".to_string(),
Expand All @@ -381,6 +403,7 @@ mod tests {
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
},
client_info: Implementation {
name: "test".to_string(),
Expand Down Expand Up @@ -641,6 +664,7 @@ mod tests {
capabilities: ClientCapabilities {
roots: None,
sampling: None,
elicitation: None,
},
client_info: Implementation {
name: "test".to_string(),
Expand Down
Loading